- import contextlib
- import itertools
- import math
- import operator
- import weakref
- from enum import Enum
- from functools import partial, reduce
- from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union
- import torch
- import torch._prims_common as utils
- import torch.library
- from torch import sym_float, Tensor, TypedStorage
- from torch._C import _get_default_device
- from torch._prims.nvfuser_prims import register_nvprims
- from torch._prims_common import (
- check,
- Dim,
- DimsSequenceType,
- DimsType,
- IntLike,
- Number,
- NumberType,
- RETURN_TYPE,
- ShapeType,
- StrideType,
- TensorLike,
- TensorLikeType,
- type_to_dtype,
- )
- from torch._prims_common.wrappers import backwards_not_supported
- from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
- from torch.overrides import handle_torch_function, has_torch_function
- from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
- prim = torch.library.Library("prims", "DEF")
- prim_impl = torch.library.Library("prims", "IMPL", "CompositeExplicitAutograd")
- prim_backend_select_impl = torch.library.Library("prims", "IMPL", "BackendSelect")
- prim_autograd_impl = torch.library.Library("prims", "IMPL", "Autograd")
- prim_meta_impl = torch.library.Library("prims", "IMPL", "Meta")
- # Experimental module containing prototype "primitive" operations.
- __all__ = [
- #
- # Common datastructures and helpers
- #
- "RETURN_TYPE",
- #
- # Elementwise unary prims
- #
- "abs",
- "acos",
- "acosh",
- "asin",
- "asinh",
- "atan",
- "atanh",
- "cos",
- "cosh",
- "bessel_i0",
- "bessel_i0e",
- "bessel_i1",
- "bessel_i1e",
- "bessel_j0",
- "bessel_j1",
- "bitwise_not",
- "cbrt",
- "ceil",
- "conj_physical",
- "digamma",
- "erf",
- "erf_inv",
- "erfc",
- "erfcx",
- "exp",
- "expm1",
- "exp2",
- "fill",
- "floor",
- "imag",
- "isfinite",
- "lgamma",
- "log",
- "log1p",
- "log2",
- "log10",
- "ndtri",
- "neg",
- "real",
- "reciprocal",
- "round",
- "sign",
- "signbit",
- "sin",
- "sinh",
- "spherical_bessel_j0",
- "sqrt",
- "tan",
- "tanh",
- "trunc",
- #
- # Elementwise binary prims
- #
- "add",
- "atan2",
- "bitwise_and",
- "bitwise_or",
- "bitwise_xor",
- # 'complex', # needs custom meta
- "div",
- "eq",
- "fmax",
- "fmin",
- "fmod",
- "gcd",
- "ge",
- "gt",
- "hypot",
- "igamma",
- "igammac",
- "le",
- "lt",
- "maximum",
- "minimum",
- "mul",
- "ne",
- "nextafter",
- "pow",
- "remainder",
- "rsqrt",
- "shift_left",
- "shift_right_arithmetic",
- "shift_right_logical", # not implemented
- "sub",
- "zeta",
- #
- # View prims
- #
- "as_strided",
- "broadcast_in_dim",
- "collapse_view",
- "conj",
- "expand_dims",
- "slice",
- "slice_in_dim", # implemented using slice -- make this a ref?
- "split_dim",
- "squeeze",
- "transpose",
- "view_of",
- #
- # Functionalized view mutations
- #
- "as_strided_scatter",
- #
- # Shape prims
- #
- "collapse",
- "cat",
- "reshape",
- "rev",
- #
- # Conditional prims
- #
- "where",
- #
- # Data conversion and movement prims
- #
- "clone",
- "convert_element_type",
- "device_put",
- "item",
- "maximum_value",
- "minimum_value",
- "to_dtype",
- "copy_strided",
- #
- # Inplace prims
- #
- "copy_to",
- "resize",
- # "_set", # Commented out, see note below
- #
- # Reduction prims
- #
- "amax",
- "amin",
- "prod",
- "sum",
- "var",
- #
- # Tensor Creation Prims
- #
- "empty_strided",
- "scalar_tensor",
- "iota",
- #
- # Linear algebra (linalg) Prims
- #
- "svd",
- #
- # Randomness Prims
- #
- "normal",
- "_uniform_helper",
- #
- # FFT prims
- #
- "fft_r2c",
- "fft_c2c",
- "fft_c2r",
- ]
- def TensorMeta(
- tensorlike: Optional[Union[NumberType, torch.Tensor]] = None,
- *,
- shape: Optional[ShapeType] = None,
- strides: Optional[StrideType] = None,
- dtype: Optional[torch.dtype] = None,
- device: Optional[Union[torch.device, str]] = None,
- ):
- if isinstance(tensorlike, Number):
- assert not shape and (shape is None or isinstance(shape, Sequence))
- assert not strides and (strides is None or isinstance(strides, Sequence))
- inferred_shape: Tuple[int, ...] = ()
- inferred_strides: Tuple[int, ...] = ()
- inferred_dtype = type_to_dtype(type(tensorlike))
- inferred_device = torch.device("cpu")
- # TODO: This looks wrong, a number that is wrapped into a tensor
- # needs to behave differently than a scalar tensor for type
- # promotion purposes
- elif tensorlike is not None:
- assert isinstance(tensorlike, torch.Tensor)
- inferred_shape = tuple(tensorlike.shape)
- inferred_strides = tuple(tensorlike.stride())
- inferred_dtype = tensorlike.dtype
- inferred_device = tensorlike.device
- else:
- # If no tensorlike "example" is given then all metadata
- # must be provided explicitly
- assert shape is not None
- assert strides is not None
- assert dtype is not None
- assert device is not None
- shape = inferred_shape if shape is None else tuple(shape)
- strides = inferred_strides if strides is None else tuple(strides)
- dtype = inferred_dtype if dtype is None else dtype
- device = inferred_device if device is None else device
- if isinstance(device, str):
- device = torch.device(device)
- return torch.empty_strided(shape, strides, dtype=dtype, device=device)
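- # Illustrative sketch, not part of the original module: TensorMeta only propagates
- # metadata (shape, strides, dtype, device); the values of the returned tensor are
- # uninitialized. Assuming the default dtype is float32:
- #
- #   >>> TensorMeta(2.0).shape, TensorMeta(2.0).dtype
- #   (torch.Size([]), torch.float32)
- #   >>> m = TensorMeta(torch.ones(2, 3))
- #   >>> m.shape, m.stride()
- #   (torch.Size([2, 3]), (3, 1))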
- def _make_prim(
- *,
- schema: str,
- return_type: Union[RETURN_TYPE, Tuple[RETURN_TYPE, ...]],
- meta: Callable,
- impl_aten: Callable,
- doc: str,
- ):
- """
- Creates a primitive operation.
- """
- prim.define(schema)
- def _prim_impl(*args, **kwargs):
- # always run the meta function because aten implementation will
- # typically accept more inputs (e.g., it will do promotion and
- # broadcasting) which we want to reject
- meta(*args, **kwargs)
- return impl_aten(*args, **kwargs)
- # Right now prims don't support autograd (we can and should add an
- # argument that provides an implementation for backward here.) Because we
- # don't have derivative formulas, we must setup a custom autograd function
- # that raises an error if backwards is invoked
- def _autograd_impl(*args, **kwargs):
- return backwards_not_supported(_prim)(*args, **kwargs)
- def _backend_select_impl(*args, **kwargs):
- if kwargs.get("device") and kwargs["device"].type == "meta":
- return meta(*args, **kwargs)
- else:
- return _prim_impl(*args, **kwargs)
- name = schema.split("(")[0]
- prim_impl.impl(name, _prim_impl)
- prim_autograd_impl.impl(name, _autograd_impl)
- prim_meta_impl.impl(name, meta)
- _prim_packet = getattr(torch._ops.ops.prims, name)
- _prim = _prim_packet.default
- from torch._subclasses.fake_tensor import contains_tensor_types
- if not any(contains_tensor_types(a.type) for a in _prim._schema.arguments):
- prim_backend_select_impl.impl(name, _backend_select_impl)
- for p in (_prim_packet, _prim):
- p.__doc__ = doc
- p.return_type = return_type # type: ignore[attr-defined]
- p.schema = schema
- p.prim_impl = _prim_impl
- p.prim_meta_impl = meta
- p.impl_aten = impl_aten
- return _prim
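- # Illustrative sketch, not part of the original module: every prim created through
- # _make_prim is reachable both as a module attribute (e.g. torch._prims.add) and via
- # the dispatcher (torch.ops.prims.add.default). Because _prim_impl always runs the
- # meta function first, inputs that ATen would broadcast or promote are rejected:
- #
- #   >>> torch.ops.prims.add(torch.ones(2, 2), torch.ones(2, 2))   # same shape: OK
- #   >>> torch.ops.prims.add(torch.ones(2, 2), torch.ones(2))      # raises: prims do not broadcast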
- class ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND(Enum):
- DEFAULT = (0,)
- ALWAYS_BOOL = (2,)
- COMPLEX_TO_FLOAT = (3,)
- # TODO: implement dtype validation here, too, or on the corresponding refs
- def _elementwise_meta(
- *args,
- type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND,
- args_with_fixed_dtypes: Optional[Tuple[TensorLikeType, ...]] = None,
- ) -> FakeTensor:
- """
- Meta function for elementwise operations that produce outputs in the same dtype
- as their inputs.
- Stride logic is currently incorrect.
- """
- assert len(args) > 0
- utils.check_same_dtype(*args)
- args_ = list(args)
- if args_with_fixed_dtypes is not None:
- args_ = list(args_with_fixed_dtypes) + args_
- utils.check_same_device(*args_, allow_cpu_scalar_tensors=True)
- utils.check_same_shape(*args_, allow_cpu_scalar_tensors=True)
- strides = utils.compute_elementwise_output_strides(*args_)
- shape = utils.extract_shape(*args_, allow_cpu_scalar_tensors=True)
- # Acquires the dtype
- dtype = None
- scalar_type = None
- for arg in args:
- if isinstance(arg, TensorLike):
- if not utils.is_cpu_scalar_tensor(arg):
- dtype = arg.dtype
- break
- else:
- dtype = arg.dtype
- elif isinstance(arg, Number):
- scalar_type = type(arg)
- if dtype is None and scalar_type is not None:
- dtype = utils.type_to_dtype(scalar_type)
- # Acquires the device (if it exists) or number
- device = None
- number = None
- for arg in args_:
- if isinstance(arg, TensorLike):
- if utils.is_cpu_scalar_tensor(arg):
- if device is None:
- device = arg.device
- # keep going, in case there is a cuda tensor later
- else:
- device = arg.device
- break
- elif isinstance(arg, Number):
- if number is None:
- number = arg
- # NOTE: type promotion behavior here is mostly hidden from tests because
- # references will typically handle the type promotion properly even if this doesn't
- # (but getting it wrong will cause too many casts to be inserted in traces!)
- if device is not None:
- assert dtype is not None
- if type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT:
- dtype = dtype
- elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
- dtype = torch.bool
- elif type_promotion == ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
- if utils.is_complex_dtype(dtype):
- dtype = utils.corresponding_real_dtype(dtype)
- else:
- dtype = dtype
- return TensorMeta(device=device, shape=shape, strides=strides, dtype=dtype)
- # Number case
- # TODO: fix number type promotion (bool, complex->float)
- # For now for symint/float, just implementing the common / simple cases of (int,float,symint,symfloat)
- seen_float = False
- if isinstance(number, (torch.SymInt, torch.SymFloat)):
- for a in args:
- assert isinstance(a, (int, float, torch.SymInt, torch.SymFloat)), "NYI"
- seen_float = seen_float or isinstance(a, (float, torch.SymFloat))
- if seen_float:
- number = sym_float(number)
- return TensorMeta(number) # type: ignore[arg-type]
- def _complex_only_elementwise_meta(*args, **kwargs):
- utils.check(
- utils.is_complex_dtype(args[0].dtype), lambda: "Only complex dtype is supported"
- )
- return _elementwise_meta(*args, **kwargs)
- def _make_elementwise_unary_prim(
- name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
- ):
- """
- Creates an elementwise unary prim.
- """
- return _make_prim(
- schema=f"{name}(Tensor self) -> Tensor",
- meta=partial(_elementwise_meta, type_promotion=type_promotion),
- return_type=RETURN_TYPE.NEW,
- **kwargs,
- )
- def _make_elementwise_binary_prim(
- name: str, *, type_promotion: ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND, **kwargs
- ):
- """
- Creates an elementwise binary prim.
- """
- return _make_prim(
- schema=f"{name}(Tensor self, Tensor other) -> Tensor",
- meta=partial(_elementwise_meta, type_promotion=type_promotion),
- return_type=RETURN_TYPE.NEW,
- **kwargs,
- )
- def _not_impl(*args, **kwargs):
- raise NotImplementedError
- #
- # Elementwise unary operations
- #
- abs = _make_elementwise_unary_prim(
- "abs",
- impl_aten=torch.abs,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
- )
- acos = _make_elementwise_unary_prim(
- "acos",
- impl_aten=torch.acos,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- acosh = _make_elementwise_unary_prim(
- "acosh",
- impl_aten=torch.acosh,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- asin = _make_elementwise_unary_prim(
- "asin",
- impl_aten=torch.asin,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- asinh = _make_elementwise_unary_prim(
- "asinh",
- impl_aten=torch.asinh,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- atan = _make_elementwise_unary_prim(
- "atan",
- impl_aten=torch.atan,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- atanh = _make_elementwise_unary_prim(
- "atanh",
- impl_aten=torch.atanh,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- cos = _make_elementwise_unary_prim(
- "cos",
- impl_aten=torch.cos,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- cosh = _make_elementwise_unary_prim(
- "cosh",
- impl_aten=torch.cosh,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bessel_j0 = _make_elementwise_unary_prim(
- "bessel_j0",
- impl_aten=torch.special.bessel_j0,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bessel_j1 = _make_elementwise_unary_prim(
- "bessel_j1",
- impl_aten=torch.special.bessel_j1,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bessel_i0 = _make_elementwise_unary_prim(
- "bessel_i0",
- impl_aten=torch.i0,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bessel_i0e = _make_elementwise_unary_prim(
- "bessel_i0e",
- impl_aten=torch.special.i0e,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bessel_i1 = _make_elementwise_unary_prim(
- "bessel_i1",
- impl_aten=torch.special.i1,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bessel_i1e = _make_elementwise_unary_prim(
- "bessel_i1e",
- impl_aten=torch.special.i1e,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bitwise_not = _make_elementwise_unary_prim(
- "bitwise_not",
- impl_aten=torch.bitwise_not,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def _cbrt_aten(a: torch.Tensor) -> Tensor:
- utils.check(
- not a.is_complex(),
- lambda: "cbrt: Complex inputs not supported. Consider calling torch.pow(a, 1.0/3.0)",
- )
- # Returns the real cubic root of the number.
- # Note that if a < 0, pow(a, (1. / 3.)) returns the complex number
- # exp(1/3 * log(a)) = exp(1/3 * (log(abs(a)) + pi*i)) = cbrt(abs(a)) * e^{pi/3*i}
- # which is a complex number.
- # For more info see the section Note in
- # https://en.cppreference.com/w/cpp/numeric/math/cbrt
- return torch.copysign(torch.pow(a.abs(), 1 / 3), a)
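- # Worked example for the note above (illustrative, not part of the original module):
- #
- #   >>> _cbrt_aten(torch.tensor([-8.0, 27.0]))
- #   tensor([-2., 3.])
- #
- # whereas the principal complex cube root of -8, exp((log(8) + pi*i) / 3), is
- # approximately 1 + 1.732j.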
- cbrt = _make_elementwise_unary_prim(
- "cbrt",
- impl_aten=_cbrt_aten,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- ceil = _make_elementwise_unary_prim(
- "ceil",
- impl_aten=torch.ceil,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def _conj_physical_meta(input: TensorLikeType) -> TensorLikeType:
- if not input.dtype.is_complex:
- raise RuntimeError("prims.conj_physical is only defined for complex dtypes")
- strides = utils.compute_elementwise_output_strides(input)
- return TensorMeta(input, strides=strides)
- conj_physical = _make_prim(
- schema="conj_physical(Tensor self) -> Tensor",
- meta=_conj_physical_meta,
- impl_aten=torch._conj_physical,
- doc="Returns the physical conjugation of a complex tensor",
- return_type=RETURN_TYPE.NEW,
- )
- def _clone_meta(
- input: TensorLikeType, *, memory_format: torch.memory_format = torch.preserve_format
- ) -> TensorLikeType:
- if memory_format != torch.preserve_format:
- return torch.empty(
- input.shape,
- dtype=input.dtype,
- layout=input.layout,
- device=input.device,
- requires_grad=input.requires_grad,
- memory_format=memory_format,
- )
- # memory_format == torch.preserve_format
- strides = utils.compute_elementwise_output_strides(input)
- return torch.empty_strided(
- input.shape,
- strides,
- dtype=input.dtype,
- layout=input.layout,
- device=input.device,
- requires_grad=input.requires_grad,
- )
- clone = _make_prim(
- schema="clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor",
- meta=_clone_meta,
- impl_aten=torch.clone,
- doc="Returns the copy of a tensor",
- return_type=RETURN_TYPE.NEW,
- )
- digamma = _make_elementwise_unary_prim(
- "digamma",
- impl_aten=torch.digamma,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- erf = _make_elementwise_unary_prim(
- "erf",
- impl_aten=torch.erf,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- erf_inv = _make_elementwise_unary_prim(
- "erf_inv",
- impl_aten=torch.special.erfinv,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- erfc = _make_elementwise_unary_prim(
- "erfc",
- impl_aten=torch.special.erfc,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- erfcx = _make_elementwise_unary_prim(
- "erfcx",
- impl_aten=torch.special.erfcx,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- exp = _make_elementwise_unary_prim(
- "exp",
- impl_aten=torch.exp,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- expm1 = _make_elementwise_unary_prim(
- "expm1",
- impl_aten=torch.special.expm1,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- exp2 = _make_elementwise_unary_prim(
- "exp2",
- impl_aten=torch.special.exp2,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def _fill_meta(a: TensorLikeType, value: NumberType) -> TensorLikeType:
- return _elementwise_meta(
- a, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
- )
- # NOTE: fill uses _make_prim directly because it has a value parameter
- fill = _make_prim(
- schema="fill(Tensor self, Scalar value) -> Tensor",
- return_type=RETURN_TYPE.NEW,
- meta=_fill_meta,
- impl_aten=torch.fill,
- doc="",
- )
- floor = _make_elementwise_unary_prim(
- "floor",
- impl_aten=torch.floor,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- imag = _make_prim(
- schema="imag(Tensor self) -> Tensor",
- meta=partial(
- _complex_only_elementwise_meta,
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
- ),
- return_type=RETURN_TYPE.VIEW,
- impl_aten=torch.imag,
- doc="",
- )
- isfinite = _make_elementwise_unary_prim(
- "isfinite",
- impl_aten=torch.isfinite,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- lgamma = _make_elementwise_unary_prim(
- "lgamma",
- impl_aten=torch.lgamma,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- log = _make_elementwise_unary_prim(
- "log",
- impl_aten=torch.log,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- log1p = _make_elementwise_unary_prim(
- "log1p",
- impl_aten=torch.log1p,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- log2 = _make_elementwise_unary_prim(
- "log2",
- impl_aten=torch.log2,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- log10 = _make_elementwise_unary_prim(
- "log10",
- impl_aten=torch.log10,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- real = _make_prim(
- schema="real(Tensor self) -> Tensor",
- meta=partial(
- _complex_only_elementwise_meta,
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
- ),
- return_type=RETURN_TYPE.VIEW,
- impl_aten=torch.real,
- doc="",
- )
- reciprocal = _make_elementwise_unary_prim(
- "reciprocal",
- impl_aten=torch.reciprocal,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- ndtri = _make_elementwise_unary_prim(
- "ndtri",
- impl_aten=torch.special.ndtri,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- neg = _make_elementwise_unary_prim(
- "neg",
- impl_aten=torch.neg,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- round = _make_elementwise_unary_prim(
- "round",
- impl_aten=torch.round,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- rsqrt = _make_elementwise_unary_prim(
- "rsqrt",
- impl_aten=torch.rsqrt,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- sign = _make_elementwise_unary_prim(
- "sign",
- impl_aten=torch.sign,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- signbit = _make_elementwise_unary_prim(
- "signbit",
- impl_aten=torch.signbit,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- sin = _make_elementwise_unary_prim(
- "sin",
- impl_aten=torch.sin,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- sinh = _make_elementwise_unary_prim(
- "sinh",
- impl_aten=torch.sinh,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- spherical_bessel_j0 = _make_elementwise_unary_prim(
- "spherical_bessel_j0",
- impl_aten=torch.special.spherical_bessel_j0,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- sqrt = _make_elementwise_unary_prim(
- "sqrt",
- impl_aten=torch.sqrt,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- tan = _make_elementwise_unary_prim(
- "tan",
- impl_aten=torch.tan,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- tanh = _make_elementwise_unary_prim(
- "tanh",
- impl_aten=torch.tanh,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- trunc = _make_elementwise_unary_prim(
- "trunc",
- impl_aten=torch.trunc,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- #
- # Elementwise binary operations
- #
- add = _make_elementwise_binary_prim(
- name="add",
- impl_aten=torch.add,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- atan2 = _make_elementwise_binary_prim(
- name="atan2",
- impl_aten=torch.atan2,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bitwise_and = _make_elementwise_binary_prim(
- "bitwise_and",
- impl_aten=torch.bitwise_and,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bitwise_or = _make_elementwise_binary_prim(
- "bitwise_or",
- impl_aten=torch.bitwise_or,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- bitwise_xor = _make_elementwise_binary_prim(
- "bitwise_xor",
- impl_aten=torch.bitwise_xor,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- # TODO: complex needs a special meta to account for its float -> complex behavior
- # complex = _make_elementwise_binary_prim(
- # impl_aten=torch.complex,
- # doc="",
- # )
- # div prim performs truncation division on integer inputs
- # and true division for floating and complex inputs
- def _div_aten(a, b):
- is_integral = isinstance(a, (bool, int, torch.SymInt)) or (
- isinstance(a, torch.Tensor) and utils.is_integer_dtype(a.dtype)
- )
- if is_integral:
- return torch.div(a, b, rounding_mode="trunc")
- else:
- return torch.true_divide(a, b)
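- # Illustrative behavior of the div prim (not part of the original module):
- #
- #   >>> _div_aten(torch.tensor([7, -7]), torch.tensor([2, 2]))
- #   tensor([ 3, -3])            # truncation division for integer dtypes
- #   >>> _div_aten(torch.tensor([7.0, -7.0]), torch.tensor([2.0, 2.0]))
- #   tensor([ 3.5000, -3.5000])  # true division for floating dtypes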
- div = _make_elementwise_binary_prim(
- "div",
- impl_aten=_div_aten,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- eq = _make_elementwise_binary_prim(
- "eq",
- impl_aten=torch.eq,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- fmax = _make_elementwise_binary_prim(
- "fmax",
- impl_aten=torch.fmax,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- fmin = _make_elementwise_binary_prim(
- "fmin",
- impl_aten=torch.fmin,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- fmod = _make_elementwise_binary_prim(
- "fmod",
- impl_aten=torch.fmod,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- gcd = _make_elementwise_binary_prim(
- "gcd",
- impl_aten=torch.gcd,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- ge = _make_elementwise_binary_prim(
- "ge",
- impl_aten=torch.ge,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- gt = _make_elementwise_binary_prim(
- "gt",
- impl_aten=torch.gt,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- hypot = _make_elementwise_binary_prim(
- "hypot",
- impl_aten=torch.hypot,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- igamma = _make_elementwise_binary_prim(
- "igamma",
- impl_aten=torch.special.gammainc,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- igammac = _make_elementwise_binary_prim(
- "igammac",
- impl_aten=torch.special.gammaincc,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- le = _make_elementwise_binary_prim(
- "le",
- impl_aten=torch.le,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- lt = _make_elementwise_binary_prim(
- "lt",
- impl_aten=torch.lt,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- # Note: the following impls are because torch.maximum and torch.minimum do not support scalar inputs
- def _maximum_aten(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ) -> TensorLikeType:
- if isinstance(a, TensorLike) and isinstance(b, Number):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- elif isinstance(b, TensorLike) and isinstance(a, Number):
- a = scalar_tensor(a, dtype=b.dtype, device=b.device)
- return torch.maximum(a, b) # type: ignore[arg-type]
- maximum = _make_elementwise_binary_prim(
- "maximum",
- impl_aten=_maximum_aten,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- def _minimum_aten(
- a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
- ) -> TensorLikeType:
- if isinstance(a, TensorLike) and isinstance(b, Number):
- b = scalar_tensor(b, dtype=a.dtype, device=a.device)
- elif isinstance(b, TensorLike) and isinstance(a, Number):
- a = scalar_tensor(a, dtype=b.dtype, device=b.device)
- return torch.minimum(a, b) # type: ignore[arg-type]
- minimum = _make_elementwise_binary_prim(
- "minimum",
- impl_aten=_minimum_aten,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- mul = _make_elementwise_binary_prim(
- "mul",
- impl_aten=torch.mul,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- ne = _make_elementwise_binary_prim(
- "ne",
- impl_aten=torch.ne,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.ALWAYS_BOOL,
- )
- nextafter = _make_elementwise_binary_prim(
- "nextafter",
- impl_aten=torch.nextafter,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- pow = _make_elementwise_binary_prim(
- "pow",
- impl_aten=torch.pow,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- remainder = _make_elementwise_binary_prim(
- "remainder",
- impl_aten=torch.remainder,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- shift_left = _make_elementwise_binary_prim(
- "shift_left",
- impl_aten=torch.bitwise_left_shift,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- shift_right_arithmetic = _make_elementwise_binary_prim(
- "shift_right_arithmetic",
- impl_aten=torch.bitwise_right_shift,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- shift_right_logical = _not_impl
- sub = _make_elementwise_binary_prim(
- "sub",
- impl_aten=torch.sub,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- zeta = _make_elementwise_binary_prim(
- "zeta",
- impl_aten=torch.special.zeta,
- doc="",
- type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
- )
- #
- # View operations
- #
- def _as_strided_meta(
- a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int
- ) -> TensorLikeType:
- assert len(size) == len(stride)
- assert storage_offset >= 0
- utils.validate_strides(stride)
- utils.validate_shape(size)
- if reduce(operator.mul, size, 1) == 0:
- # NOTE: This special case is to avoid having to acquire the storage below;
- # as_strided to shapes with no elements is trivially valid, so it's OK
- pass
- elif isinstance(a, torch.Tensor):
- utils.check_in_bounds_for_storage(
- a._typed_storage(), size, stride, storage_offset
- )
- return torch.as_strided(a, size, stride, storage_offset)
- def _as_strided_aten(
- a: Tensor, size: ShapeType, stride: StrideType, storage_offset: int
- ) -> Tensor:
- return torch.as_strided(a, size, stride, storage_offset)
- _as_strided_doc = """
- Creates a view of the tensor with the given shape (size), strides (stride) and
- storage offset (storage_offset).
- """
- as_strided = _make_prim(
- schema="as_strided(Tensor(a!) a, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor(a!)",
- meta=_as_strided_meta,
- impl_aten=_as_strided_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_as_strided_doc,
- )
- def _broadcast_in_dim_meta(
- a: TensorLikeType, shape: ShapeType, broadcast_dimensions: Sequence[int]
- ):
- # Type checks
- assert isinstance(a, TensorLike)
- assert isinstance(shape, Sequence)
- assert isinstance(broadcast_dimensions, Sequence)
- # every dimension must be accounted for
- assert a.ndim == len(broadcast_dimensions)
- # broadcast shape must have weakly more dimensions
- assert len(shape) >= a.ndim
- # broadcast_dimensions must be an ascending sequence
- # (no relative reordering of dims) of integers and
- # each dimension must be within the new shape
- def _greater_than_reduce(acc, x):
- assert isinstance(x, Dim)
- assert x > acc
- assert x < len(shape)
- return x
- reduce(lambda acc, x: _greater_than_reduce(acc, x), broadcast_dimensions, -1)
- # shape must be broadcastable to
- for idx, new_idx in enumerate(broadcast_dimensions):
- assert a.shape[idx] == 1 or a.shape[idx] == shape[new_idx]
- new_strides = []
- original_idx = 0
- for idx in range(len(shape)):
- if idx in broadcast_dimensions:
- # Assigns a stride of zero to dimensions
- # which were actually broadcast
- if a.shape[original_idx] != shape[idx]:
- new_strides.append(0)
- else:
- new_strides.append(a.stride()[original_idx])
- original_idx = original_idx + 1
- else:
- if shape[idx] != 1:
- new_strides.append(0)
- elif original_idx == a.ndim:
- new_strides.append(1)
- else:
- new_strides.append(a.stride()[original_idx] * a.size()[original_idx])
- return a.as_strided(shape, new_strides, a.storage_offset())
- def _broadcast_in_dim_aten(a, shape, broadcast_dimensions):
- s = list(shape)
- for broadcast_dimension in broadcast_dimensions:
- s[broadcast_dimension] = -1
- v = a
- for idx, x in enumerate(s):
- if x != -1:
- v = v.unsqueeze(idx)
- return v.expand(shape)
- _broadcast_in_dim_doc = """
- Creates a view of a with the specified shape.
- Allows adding dimensions of any length and broadcasting
- dimensions of length one in a to any length.
- The location of the broadcast dimensions must be specified
- using the broadcast_dimensions argument. Changing the
- relative order of dimensions is not supported.
- """
- broadcast_in_dim = _make_prim(
- schema="broadcast_in_dim(Tensor(a) a, SymInt[] shape, int[] broadcast_dimensions) -> Tensor(a)",
- meta=_broadcast_in_dim_meta,
- impl_aten=_broadcast_in_dim_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_broadcast_in_dim_doc,
- )
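- # Illustrative usage, not part of the original module: place the only dimension of a
- # (3,) tensor at index 1 of a (2, 3) result; the added dimension gets stride 0.
- #
- #   >>> a = torch.arange(3.)
- #   >>> b = broadcast_in_dim(a, (2, 3), (1,))
- #   >>> b.shape, b.stride()
- #   (torch.Size([2, 3]), (0, 1))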
- def _collapse_view_helper(
- a: TensorLikeType, start: int, end: int
- ) -> Tuple[Optional[ShapeType], Optional[StrideType]]:
- assert isinstance(a, TensorLike)
- # Special-case for zero dimensional tensors
- if a.ndim == 0:
- shape = (1,)
- strides = (1,)
- else:
- shape = a.shape # type: ignore[assignment]
- strides = a.stride() # type: ignore[assignment]
- utils.validate_idx(len(shape), start)
- utils.validate_exclusive_idx(len(shape), end)
- # Verifies end is strictly greater than start
- # (Collapse requires a non-empty interval)
- if end <= start:
- msg = "Attempting to collapse but end, {0}, is less than or equal to start, {1}!".format(
- end, start
- )
- raise ValueError(msg)
- if a.ndim == 0 or (end - 1 == start):
- return shape, strides
- length = shape[end - 1]
- stride = strides[end - 1]
- for idx in reversed(range(start, end - 1)):
- if shape[idx] == 0 or shape[idx + 1] == 0:
- length = 0
- stride = 0
- break
- if shape[idx] == 1:
- continue
- length = length * shape[idx]
- stride = min(stride, strides[idx])
- if (
- a.numel() > 0
- and shape[idx + 1] != 1
- and not (strides[idx] == strides[idx + 1] * shape[idx + 1])
- ):
- return None, None
- new_shape = shape[:start] + (length,) + shape[end:]
- new_strides = strides[:start] + (stride,) + strides[end:]
- # NOTE: when the input has no elements it's restrided as if it were contiguous
- if a.numel() == 0:
- new_strides = utils.make_contiguous_strides_for(new_shape)
- return new_shape, new_strides
- def _collapse_view_meta(a: TensorLikeType, start: int, end: int) -> TensorLikeType:
- new_shape, new_strides = _collapse_view_helper(a, start, end)
- if new_shape is None:
- msg = "Attempting to view a collapsed tensor, but no such view exists!"
- raise ValueError(msg)
- if new_strides is None:
- return a.view(new_shape)
- else:
- return a.as_strided(new_shape, new_strides, a.storage_offset())
- def _collapse_view_aten(a: Tensor, start: int, end: int) -> Tensor:
- # Special-cases zero-dim tensors
- if a.ndim == 0:
- shape = (1,)
- else:
- shape = a.shape # type: ignore[assignment]
- dim_length = 1
- for idx in range(start, end):
- dim_length = dim_length * shape[idx]
- new_shape = shape[0:start] + (dim_length,) + shape[end:]
- return a.view(new_shape)
- _collapse_view_doc = """
- Creates a view of a with the dimensions between
- start (inclusive) and end (exclusive) merged into a
- single dimension.
- If it's not possible to take such a view then an error
- is thrown. See collapse instead.
- The dimensions can be merged if and only if
- they are all "nested" with each other. That is, they all
- have the property that
- stride[i] = stride[i+1] * shape[i+1]
- for all i in [start, end - 1).
- """
- collapse_view = _make_prim(
- schema="collapse_view(Tensor(a) a, int start, int end) -> Tensor(a)",
- meta=_collapse_view_meta,
- impl_aten=_collapse_view_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_collapse_view_doc,
- )
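- # Illustrative usage, not part of the original module: a contiguous tensor can always
- # be collapsed, while a permuted one typically has no such view and raises.
- #
- #   >>> a = torch.ones(2, 3, 4)
- #   >>> collapse_view(a, 1, 3).shape
- #   torch.Size([2, 12])
- #   >>> collapse_view(a.transpose(0, 1), 1, 3)
- #   ValueError: Attempting to view a collapsed tensor, but no such view exists!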
- def _conj_meta(a: TensorLikeType) -> TensorLikeType:
- if not a.dtype.is_complex:
- raise RuntimeError("Expected complex dtype in prims.conj")
- return a.as_strided(a.shape, a.stride(), a.storage_offset())
- _conj_doc = """
- Returns a conjugated view of the original tensor
- """
- conj = _make_prim(
- schema="conj(Tensor(a) a) -> Tensor(a)",
- meta=_conj_meta,
- impl_aten=torch.conj,
- return_type=RETURN_TYPE.VIEW,
- doc=_conj_doc,
- )
- def expand_dims(
- a: TensorLikeType, dimensions: DimsSequenceType, ndim=None
- ) -> TensorLikeType:
- """
- Creates a view of a with a.ndim + len(dimensions) dimensions, with new
- dimensions of length one at the dimensions specified by dimensions.
- """
- if ndim is not None:
- # TODO: this is only here to support the unsqueeze ref
- dims = sorted(utils.canonicalize_dims(ndim, dimensions)) # type: ignore[arg-type]
- else:
- dims = sorted(utils.canonicalize_dims(a.ndim, dimensions)) # type: ignore[arg-type]
- if len(set(dims)) != len(dims):
- msg = "Received duplicate dimensions to expand in {0}".format(str(dimensions))
- raise ValueError(msg)
- new_shape = list(a.shape)
- for idx in dims:
- new_shape.insert(idx, 1)
- broadcast_dimensions = [
- idx for idx in range(len(new_shape)) if idx not in dimensions
- ]
- return broadcast_in_dim(a, new_shape, broadcast_dimensions)
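- # Illustrative usage, not part of the original module:
- #
- #   >>> expand_dims(torch.ones(2, 3), (1,)).shape
- #   torch.Size([2, 1, 3])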
- # Note: saves the Python slice object because we're about to clobber its name with the slice prim
- pyslice: Type[slice] = slice # type: ignore[has-type]
- def _slice_meta(
- a: TensorLikeType,
- start_indices: DimsSequenceType,
- limit_indices: DimsSequenceType,
- strides: Optional[StrideType] = None,
- ) -> TensorLikeType:
- _strides = strides if strides is not None else [1] * len(start_indices)
- if a.ndim != len(start_indices):
- msg = "Attempting to slice tensor of rank {0} with start_indices of length {1}!".format(
- a.ndim, len(start_indices)
- )
- raise ValueError(msg)
- if a.ndim != len(limit_indices):
- msg = "Attempting to slice tensor of rank {0} with limit_indices of length {1}!".format(
- a.ndim, len(limit_indices)
- )
- raise ValueError(msg)
- if a.ndim != len(_strides):
- msg = (
- "Attempting to slice tensor of rank {0} with strides of length {1}!".format(
- a.ndim, len(_strides)
- )
- )
- raise ValueError(msg)
- for x, y in zip(start_indices, a.shape):
- if x < 0:
- msg = "Attempting to slice a tensor with a negative start index of {0}!".format(
- x
- )
- raise ValueError(msg)
- if x > y:
- msg = (
- "Attempting to slice a tensor but a start index in {0} is greater than"
- " the length of its corresponding dimension in shape {1}".format(
- start_indices, a.shape
- )
- )
- raise ValueError(msg)
- for x, y, z in zip(limit_indices, a.shape, start_indices):
- if x < 0:
- msg = "Attempting to slice a tensor with a negative stop index of {0}!".format(
- x
- )
- raise ValueError(msg)
- if x > y:
- msg = (
- "Attempting to slice a tensor but a stop index in {0} is greater than the length of "
- " its corresponding dimension in shape {1}".format(
- limit_indices, a.shape
- )
- )
- raise ValueError(msg)
- if x < z:
- msg = (
- "Attempting to slice a tensor but a start index in {0} is greater than"
- " its corresponding stop index {1}".format(z, x)
- )
- raise ValueError(msg)
- for x in _strides:
- if x <= 0:
- msg = (
- "Attempting to slice a tensor with a non-positive step of {0}!".format(
- x
- )
- )
- raise ValueError(msg)
- new_shape = []
- for x, y, z in zip(start_indices, limit_indices, _strides):
- new_shape.append(math.floor((y - x) / z))
- new_strides = []
- for x, y in zip(a.stride(), _strides):
- new_strides.append(x * y)
- return a.as_strided(new_shape, new_strides, a.storage_offset())
- def _slice_aten(
- a: Tensor,
- start_indices: DimsSequenceType,
- limit_indices: DimsSequenceType,
- strides: Optional[StrideType] = None,
- ) -> Tensor:
- _strides = strides if strides is not None else [1] * len(start_indices)
- slices = []
- for start, stop, step in zip(start_indices, limit_indices, _strides):
- slices.append(pyslice(start, stop, step))
- return operator.getitem(a, slices) # type: ignore[call-overload]
- _slice_doc = """
- Creates a view of a "bounding box" within the tensor.
- The bounding box is specified independently in each of the tensor's dimensions.
- start_indices and limit_indices describe the box's boundaries for their corresponding
- dimensions. If strides is specified then they specify the step size between elements
- in their corresponding dimension.
- This operation is analogous to slicing in NumPy, but does not permit slices where
- the stop indices are less than the start indices.
- """
- slice = _make_prim(
- schema="slice(Tensor(a) a, SymInt[] start_indices, SymInt[] limit_indices, SymInt[]? strides=None) -> Tensor(a)",
- meta=_slice_meta,
- impl_aten=_slice_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_slice_doc,
- )
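- # Illustrative usage, not part of the original module (within this module the name
- # slice refers to the prim defined above, not the Python builtin):
- #
- #   >>> a = torch.arange(12.).reshape(3, 4)
- #   >>> slice(a, (0, 1), (2, 3))
- #   tensor([[1., 2.],
- #           [5., 6.]])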
- def _slice_in_dim_meta(
- a: TensorLikeType,
- start_index: int,
- limit_index: int,
- stride: int = 1,
- axis: int = 0,
- ) -> TensorLikeType:
- if axis < 0:
- msg = "slice_in_dim: received a negative axis {0}".format(axis)
- raise ValueError(msg)
- if axis >= a.ndim:
- msg = "slice_in_dim: axis {0} is greater or equal to the rank {1} of the tensor".format(
- axis, a.ndim
- )
- raise ValueError(msg)
- if start_index < 0:
- msg = "slice_in_dim: received a negative start_index {0}".format(start_index)
- raise ValueError(msg)
- if start_index > a.shape[axis]:
- msg = "slice_in_dim: start_index is greater than the length {0} of dimension {1}".format(
- start_index, axis
- )
- raise ValueError(msg)
- if limit_index > a.shape[axis]:
- msg = "slice_in_dim: limit_index is greater than the length {0} of dimension {1}".format(
- limit_index, axis
- )
- raise ValueError(msg)
- if limit_index < start_index:
- msg = "slice_in_dim: received a limit_index {0} less than the start_index {1}".format(
- limit_index, start_index
- )
- raise ValueError(msg)
- if stride < 1:
- msg = "slice_in_dim: received a non-positive stride of {0}!".format(stride)
- raise ValueError(msg)
- start_indices = [0] * a.ndim
- limit_indices = list(a.shape)
- strides = [1] * a.ndim
- start_indices[axis] = start_index
- limit_indices[axis] = limit_index
- strides[axis] = stride
- return _slice_meta(a, start_indices, limit_indices, strides)
- def _slice_in_dim_aten(
- a: Tensor,
- start_index: int,
- limit_index: int,
- stride: int = 1,
- axis: int = 0,
- ) -> Tensor:
- start_indices = [0] * a.ndim
- limit_indices = list(a.shape)
- strides = [1] * a.ndim
- start_indices[axis] = start_index
- limit_indices[axis] = limit_index
- strides[axis] = stride
- return slice(a, start_indices, limit_indices, strides)
- _slice_in_dim_doc = """
- Convenience wrapper for slicing just one dimension using slice.
- """
- # TODO: make stride SymInt
- slice_in_dim = _make_prim(
- schema="slice_in_dim(Tensor(a) a, SymInt start_index, SymInt limit_index, int stride=1, int axis=0) -> Tensor(a)",
- meta=_slice_in_dim_meta,
- impl_aten=_slice_in_dim_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_slice_in_dim_doc,
- )
- def _split_dim_meta(a: TensorLikeType, dim: int, outer_length: int) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- utils.validate_idx(a.ndim, dim)
- utils.validate_dim_length(outer_length)
- # Verifies the dim can be split with the specified lhs_length
- inner_length = a.shape[dim] // outer_length
- if (a.shape[dim] % outer_length) != 0:
- msg = "Attempting to split dimension of length {0}, but outer length of {1} divides it with a remainder!".format(
- a.shape[dim], outer_length
- )
- raise ValueError(msg)
- new_shape: List[int] = []
- new_strides: List[int] = []
- for idx in range(a.ndim):
- if idx == dim:
- new_shape.extend((outer_length, inner_length))
- new_strides.extend((a.stride()[idx] * inner_length, a.stride()[idx]))
- else:
- new_shape.append(a.shape[idx])
- new_strides.append(a.stride()[idx])
- return a.as_strided(new_shape, new_strides, a.storage_offset())
- def _split_dim_aten(a: Tensor, dim: int, outer_length: int) -> Tensor:
- inner_length = a.shape[dim] // outer_length
- new_shape = a.shape[0:dim] + (outer_length, inner_length) + a.shape[dim + 1 :]
- return a.view(new_shape)
- _split_dim_doc = """
- Creates a view of a with the given dimension (of length l) split
- into two dimensions, with the outer of the two having
- length outer_length and the inner of the two having computed
- length inner_length such that outer_length * inner_length = l.
- """
- # TODO: consider renaming split_dim_view
- split_dim = _make_prim(
- schema="split_dim(Tensor(a) a, int dim, SymInt outer_length) -> Tensor(a)",
- meta=_split_dim_meta,
- impl_aten=_split_dim_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_split_dim_doc,
- )
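- # Illustrative usage, not part of the original module: split a dimension of length 12
- # into an outer dimension of length 3 and a computed inner dimension of length 4.
- #
- #   >>> split_dim(torch.ones(2, 12), 1, 3).shape
- #   torch.Size([2, 3, 4])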
- # Note: allows dimensions to be specified redundantly
- def _squeeze_meta(a: TensorLikeType, dimensions: Sequence) -> TensorLikeType:
- assert isinstance(a, TensorLike)
- for idx in dimensions:
- utils.validate_idx(a.ndim, idx)
- assert a.shape[idx] == 1
- new_shape = []
- new_strides = []
- for idx in range(len(a.shape)):
- if idx in dimensions:
- continue
- new_shape.append(a.shape[idx])
- new_strides.append(a.stride()[idx])
- return a.as_strided(new_shape, new_strides, a.storage_offset())
- _squeeze_doc = """
- Creates a view of the tensor with the specified dimensions removed.
- The removed dimensions must each have length one.
- """
- squeeze = _make_prim(
- schema="squeeze(Tensor(a) a, int[] dimensions) -> Tensor(a)",
- meta=_squeeze_meta,
- impl_aten=torch.squeeze,
- return_type=RETURN_TYPE.VIEW,
- doc=_squeeze_doc,
- )
- def _transpose_meta(a: TensorLikeType, permutation: DimsSequenceType) -> TensorLikeType:
- if a.ndim != len(permutation):
- msg = "Attempting to permute a tensor of rank {0}, but received a permutation of length {1}!".format(
- a.ndim, len(permutation)
- )
- raise ValueError(msg)
- if not utils.is_valid_permutation(a.ndim, permutation):
- msg = "Received an invalid permutation, {0}!".format(permutation)
- raise ValueError(msg)
- new_shape = [0] * a.ndim
- new_strides = [0] * a.ndim
- for idx, dim in enumerate(permutation):
- new_shape[idx] = a.shape[dim]
- new_strides[idx] = a.stride()[dim]
- return a.as_strided(tuple(new_shape), tuple(new_strides), a.storage_offset())
- def _transpose_aten(a: Tensor, permutation: DimsSequenceType) -> Tensor:
- return torch.permute(a, permutation)
- _transpose_doc = """
- Creates a view of the tensor with its dimensions permuted.
- The length of the permutation must be the rank of the tensor,
- and each element of the permutation specifies the new order
- for the corresponding dimension.
- """
- transpose = _make_prim(
- schema="transpose(Tensor(a) a, int[] permutation) -> Tensor(a)",
- meta=_transpose_meta,
- impl_aten=_transpose_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_transpose_doc,
- )
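- # Illustrative usage, not part of the original module:
- #
- #   >>> transpose(torch.ones(2, 3, 4), (2, 0, 1)).shape
- #   torch.Size([4, 2, 3])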
- def _view_of_meta(a: TensorLikeType) -> TensorLikeType:
- return a.as_strided(a.shape, a.stride(), a.storage_offset())
- def _view_of_aten(a: Tensor) -> Tensor:
- return a.view(a.shape)
- _view_of_doc = """
- Creates a view of the tensor.
- """
- view_of = _make_prim(
- schema="view_of(Tensor(a) a) -> Tensor",
- meta=_view_of_meta,
- impl_aten=_view_of_aten,
- return_type=RETURN_TYPE.VIEW,
- doc=_view_of_doc,
- )
- #
- # Functionalized view mutations
- #
- def _as_strided_scatter_meta(
- input: TensorLikeType,
- src: TensorLikeType,
- size: ShapeType,
- stride: StrideType,
- storage_offset: int,
- ) -> TensorLikeType:
- utils.validate_shape(size)
- utils.validate_strides(stride)
- required_size = utils.compute_required_storage_length(size, stride, storage_offset)
- utils.check(
- input.numel() >= required_size,
- lambda: (
- f"as_strided_scatter: sizes {size}, strides {stride}, storage offset {storage_offset} "
- f" and itemsize {input.element_size()} requiring a storage size of "
- f"{required_size * input.element_size()} are out of bounds "
- f"for storage of size {input.numel() * input.element_size()}"
- ),
- )
- utils.check(
- utils.is_same_shape(src.shape, size),
- lambda: f"expected src to have a size equal to the slice of self. src size = {src.shape}, slice size = {size}",
- )
- return utils.clone_preserve_strides(input)
- _as_strided_scatter_doc = """
- Creates a new tensor equivalent to ``out = input.clone()`` after mutation by
- ``out.as_strided(size, stride, storage_offset).copy_(src)``.
- """
- as_strided_scatter = _make_prim(
- schema="as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt storage_offset) -> Tensor",
- meta=_as_strided_scatter_meta,
- impl_aten=torch.as_strided_scatter,
- return_type=RETURN_TYPE.NEW,
- doc=_as_strided_scatter_doc,
- )
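- # Illustrative usage, not part of the original module: scatter a (2, 2) source into a
- # strided window of a clone of the input.
- #
- #   >>> as_strided_scatter(torch.zeros(8), torch.ones(2, 2), (2, 2), (4, 1), 0)
- #   tensor([1., 1., 0., 0., 1., 1., 0., 0.])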
#
# Shape operations
#


def collapse(a: Tensor, start: int, end: int) -> Tensor:
    """
    Wrapper around reshape that collapses a span of dimensions.

    See collapse_view for the corresponding view operation.
    """
    dim_length = 1
    for idx in range(start, end):
        dim_length = dim_length * a.shape[idx]

    new_shape = a.shape[0:start] + (dim_length,) + a.shape[end:]
    return reshape(a, new_shape)
# TODO: review stride logic
def _cat_meta(tensors: Sequence[TensorLikeType], dim: int) -> TensorLikeType:
    # Verifies same shape (except in the concat dimension)
    shape = tensors[0].shape
    concat_length = 0
    for tensor_idx, tensor in enumerate(tensors):
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            elif length != common_length:
                raise RuntimeError(
                    f"Sizes of tensors must match except in dimension {dim}. "
                    f"Expected {common_length} but got {length} for tensor number "
                    f"{tensor_idx} in the list"
                )

    new_shape = list(tensors[0].shape).copy()
    new_shape[dim] = concat_length
    return TensorMeta(
        tensors[0],
        shape=new_shape,
        strides=utils.make_contiguous_strides_for(new_shape),
    )


def _cat_aten(tensors: Union[Tuple[Tensor, ...], List[Tensor]], dim: int) -> Tensor:
    return torch.cat(tensors, dim)
_cat_doc = """
Concatenates tensors along the specified dimension.

All tensors must have the same rank, and every dimension other than the
concatenation dimension must have the same length.
"""

cat = _make_prim(
    schema="cat(Tensor[] tensors, int dim) -> Tensor",
    meta=_cat_meta,
    impl_aten=_cat_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_cat_doc,
)
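# Illustrative example (editor's addition): concatenation only grows the
# requested dimension; all other dimensions must agree.
#
#   >>> x = torch.zeros(2, 3)
#   >>> y = torch.zeros(4, 3)
#   >>> torch.cat([x, y], dim=0).shape
#   torch.Size([6, 3])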
def _reshape_meta(a: TensorLikeType, shape: ShapeType):
    assert isinstance(a, TensorLike)
    utils.validate_shape(shape)

    # Validates the tensor and the requested shape have the
    # same number of elements
    numel = reduce(operator.mul, shape, 1)
    if numel != a.numel():
        msg = "Attempting to reshape a tensor with {0} elements to a shape with {1} elements!".format(
            a.numel(), numel
        )
        raise ValueError(msg)

    return TensorMeta(a, shape=shape, strides=utils.make_contiguous_strides_for(shape))


def _reshape_aten(a: Tensor, shape: ShapeType) -> Tensor:
    return a.reshape(shape).contiguous().clone()
_reshape_doc = """
Creates a contiguous tensor with the specified shape
containing a copy of the data in a.
"""

reshape = _make_prim(
    schema="reshape(Tensor a, SymInt[] shape) -> Tensor",
    meta=_reshape_meta,
    impl_aten=_reshape_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_reshape_doc,
)
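# Illustrative example (editor's addition): unlike torch.reshape, which may
# return a view, this prim always produces a contiguous copy, mirroring the
# ATen path used above.
#
#   >>> a = torch.arange(6)
#   >>> b = a.reshape(2, 3).contiguous().clone()
#   >>> b.shape, b.is_contiguous()
#   (torch.Size([2, 3]), True)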
def _rev_meta(a: TensorLikeType, dims: DimsSequenceType) -> TensorLikeType:
    utils.validate_dimension_indices(a.ndim, dims)
    out = torch.empty_like(a, memory_format=torch.preserve_format)
    return TensorMeta(out)


_rev_doc = """
Reverses the order of elements along the given dimensions.
"""

rev = _make_prim(
    schema="rev(Tensor a, int[] dims) -> Tensor",
    meta=_rev_meta,
    impl_aten=torch.flip,
    return_type=RETURN_TYPE.NEW,
    doc=_rev_doc,
)
#
# Conditional prims
#


def _where_meta(
    pred: TensorLikeType, a: TensorLikeType, b: TensorLikeType
) -> TensorLikeType:
    return _elementwise_meta(
        a,
        b,
        type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT,
        args_with_fixed_dtypes=(pred,),
    )


_where_doc = """
Selects elements from a and b according to pred.

Where pred is true the result contains the element from a, and
where pred is false the result contains the element from b.
"""

where = _make_prim(
    schema="where(Tensor pred, Tensor a, Tensor b) -> Tensor",
    meta=_where_meta,
    impl_aten=torch.where,
    return_type=RETURN_TYPE.NEW,
    doc=_where_doc,
)
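# Illustrative example (editor's addition): the prim follows the familiar
# torch.where selection semantics.
#
#   >>> pred = torch.tensor([True, False, True])
#   >>> torch.where(pred, torch.tensor([1, 2, 3]), torch.tensor([10, 20, 30])).tolist()
#   [1, 20, 3]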
#
# Type conversions
#


def _convert_element_type_meta(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    # Type checks
    assert isinstance(a, TensorLike)
    assert isinstance(dtype, torch.dtype)

    # dtype conversion preserves dense strides
    if torch._prims_common.is_non_overlapping_and_dense(a):
        strides = a.stride()
    else:
        strides = utils.compute_elementwise_output_strides(a)

    return TensorMeta(a, strides=strides, dtype=dtype)


def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
    # Propagates requires_grad when possible
    if not utils.is_grad_dtype(dtype):
        requires_grad = False
    else:
        # TODO: update meta objects so this can be acquired directly
        try:
            requires_grad = a.requires_grad
        except Exception:
            requires_grad = False

    result = torch.empty_like(
        a, device=a.device, dtype=dtype, requires_grad=requires_grad
    )
    with torch.no_grad():
        return copy_to(result, a)


_convert_element_type_doc = """
Creates a copy of a tensor with the given dtype.
"""

convert_element_type = _make_prim(
    schema="convert_element_type(Tensor a, ScalarType dtype) -> Tensor",
    meta=_convert_element_type_meta,
    impl_aten=_convert_element_type_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_convert_element_type_doc,
)
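# Illustrative example (editor's addition): element type conversion behaves
# like Tensor.to(dtype); float-to-integer conversion truncates toward zero.
#
#   >>> torch.tensor([1.5, 2.5]).to(torch.int64).tolist()
#   [1, 2]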
def _device_put_meta(
    a: TensorLikeType, device: Union[str, torch.device]
) -> TensorLikeType:
    assert isinstance(a, TensorLike)
    assert isinstance(device, (str, torch.device))

    return TensorMeta(a, device=utils.canonicalize_device(device))


def _device_put_aten(a: Tensor, device: Union[str, torch.device]) -> Tensor:
    return a.to(device)


_device_put_doc = """
Creates a copy of a tensor on the given device.
"""

device_put = _make_prim(
    schema="device_put(Tensor a, Device device) -> Tensor",
    meta=_device_put_meta,
    impl_aten=_device_put_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_device_put_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _item_meta(a: TensorLikeType) -> FakeTensor:
    number_type = utils.dtype_to_type(a.dtype)
    return TensorMeta(number_type(-1))


_item_doc = """
Converts a tensor with one element to a Python number.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
item = _make_prim(
    schema="item(Tensor a) -> Scalar",
    meta=_item_meta,
    impl_aten=torch.Tensor.item,
    return_type=RETURN_TYPE.NEW,
    doc=_item_doc,
)
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _maximum_value_meta(dtype: torch.dtype) -> FakeTensor:
    number_type = utils.dtype_to_type(dtype)
    return TensorMeta(number_type(-1))


def _maximum_value_aten(dtype: torch.dtype):
    if dtype == torch.bool:
        return True
    elif dtype.is_complex or dtype.is_floating_point:
        return torch.finfo(dtype).max
    else:
        return torch.iinfo(dtype).max


_maximum_value_doc = """
Return the maximum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
maximum_value = _make_prim(
    schema="maximum_value(ScalarType dtype) -> Scalar",
    meta=_maximum_value_meta,
    impl_aten=_maximum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_maximum_value_doc,
)
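# Illustrative example (editor's addition): the ATen implementation above
# simply dispatches to finfo/iinfo.
#
#   >>> torch.finfo(torch.float32).max
#   3.4028234663852886e+38
#   >>> torch.iinfo(torch.int8).max
#   127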
# NOTE: need to model meta scalars
# See https://github.com/pytorch/pytorch/issues/78070
def _minimum_value_meta(dtype: torch.dtype) -> FakeTensor:
    number_type = utils.dtype_to_type(dtype)
    return TensorMeta(number_type(-1))


def _minimum_value_aten(dtype: torch.dtype):
    if dtype == torch.bool:
        return False
    elif dtype.is_complex or dtype.is_floating_point:
        return torch.finfo(dtype).min
    else:
        return torch.iinfo(dtype).min


_minimum_value_doc = """
Return the minimum finite value for a dtype.
"""

# TODO: create a new return type for scalars?
# FIXME: currently returns integers for boolean tensors
# https://github.com/pytorch/pytorch/issues/78071
minimum_value = _make_prim(
    schema="minimum_value(ScalarType dtype) -> Scalar",
    meta=_minimum_value_meta,
    impl_aten=_minimum_value_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_minimum_value_doc,
)
#
# Inplace operators
#


def _copy_to_meta(a: TensorLikeType, b: TensorLikeType):
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    # Validates the cast is safe
    # TODO: move this as an option on the reference
    # a_typ = utils.dtype_to_type(a.dtype)
    # b_typ = utils.dtype_to_type(b.dtype)
    # if a_typ is not utils.get_higher_type(a_typ, b_typ):
    #     raise RuntimeError(str(b.dtype), " can't be cast safely to ", str(a.dtype), "!")

    # Validates the tensors have the same number of elements
    if a.numel() != b.numel():
        msg = "Attempting to copy {0} elements to a tensor with {1} elements!".format(
            b.numel(), a.numel()
        )
        raise RuntimeError(msg)

    return a


def _copy_to_aten(a: Tensor, b: Tensor) -> Tensor:
    return a.copy_(b)


_copy_to_doc = """
Copies the data in b to a and returns the modified a.
"""

# TODO: Remove safe casting and implement on reference instead
copy_to = _make_prim(
    schema="copy_to(Tensor(a!) a, Tensor b) -> Tensor(a!)",
    meta=_copy_to_meta,
    impl_aten=_copy_to_aten,
    return_type=RETURN_TYPE.INPLACE,
    doc=_copy_to_doc,
)
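# Illustrative example (editor's addition): copy_to is the in-place copy_
# wrapped as a prim; it returns the mutated destination tensor.
#
#   >>> a = torch.zeros(3)
#   >>> b = torch.arange(3.0)
#   >>> a.copy_(b).tolist()
#   [0.0, 1.0, 2.0]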
def _copy_strided_meta(a: TensorLikeType, stride: ShapeType):
    assert isinstance(a, TensorLike)
    return torch.empty_strided(
        a.shape,
        stride,
        dtype=a.dtype,
        layout=a.layout,
        device=a.device,
        requires_grad=a.requires_grad,
    )


def _copy_strided_aten(a: Tensor, stride: ShapeType) -> Tensor:
    out = torch.empty_strided(
        a.size(),
        stride=stride,
        dtype=a.dtype,
        layout=a.layout,
        device=a.device,
        requires_grad=a.requires_grad,
    )
    out.copy_(a)
    return out


_copy_strided_doc = """
Copies the data in a to a new tensor with the same shape as a but the specified strides.
"""

copy_strided = _make_prim(
    schema="copy_strided(Tensor a, SymInt[] stride) -> Tensor",
    meta=_copy_strided_meta,
    impl_aten=_copy_strided_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_copy_strided_doc,
)
def _resize_meta(a: TensorLikeType, shape: ShapeType):
    return a.resize_(shape)


def _resize_aten(a: Tensor, shape: ShapeType) -> Tensor:
    return a.resize_(shape)


_resize_doc = """
Gives a tensor with no elements a new shape, returning the modified tensor.

The tensor's strides are contiguous and its values are uninitialized.
"""

# TODO: review support for arbitrary resizes
resize = _make_prim(
    schema="resize(Tensor(a!) a, SymInt[] shape) -> Tensor(a!)",
    meta=_resize_meta,
    impl_aten=_resize_aten,
    return_type=RETURN_TYPE.INPLACE,
    doc=_resize_doc,
)
def _reduction_meta(inp, dims, *, output_dtype=None):
    """
    Meta function for single-output reduction operations.

    Stride logic is incorrect.
    """
    assert isinstance(inp, TensorLike)
    if output_dtype is None:
        output_dtype = inp.dtype
    output_shape = utils.compute_reduction_output_shape(inp.shape, dims)
    return TensorMeta(
        shape=output_shape,
        strides=utils.make_contiguous_strides_for(output_shape),
        dtype=output_dtype,
        device=inp.device,
    )


def _var_reduction_meta(inp, dims, *, correction):
    if utils.is_complex_dtype(inp.dtype):
        output_dtype = utils.corresponding_real_dtype(inp.dtype)
    else:
        output_dtype = inp.dtype
    return _reduction_meta(inp, dims, output_dtype=output_dtype)
_sum_doc = """
Computes the sum of elements in the input tensor over the list of dimensions
specified in the dim argument
"""

_prod_doc = """
Computes the product of elements in the input tensor over the list of dimensions
specified in the dim argument
"""

_amax_doc = """
Computes the maximum value of elements in the input tensor over the list of dimensions
specified in the dim argument
"""

_amin_doc = """
Computes the minimum value of elements in the input tensor over the list of dimensions
specified in the dim argument
"""

_var_doc = """
Computes the biased variance of x over the list of dimensions specified in the dim argument
"""
def _make_reduction_prim(name: str, impl_aten, doc):
    """Creates a reduction prim."""
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, ScalarType? output_dtype=None) -> Tensor",
        meta=_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )


def _make_var_reduction_prim(name: str, impl_aten, doc):
    """Creates a variance reduction prim."""
    return _make_prim(
        schema=f"{name}(Tensor inp, int[]? dims, *, int correction, ScalarType? output_dtype=None) -> Tensor",
        meta=_var_reduction_meta,
        impl_aten=impl_aten,
        return_type=RETURN_TYPE.NEW,
        doc=doc,
    )
sum = _make_reduction_prim(
    name="sum",
    impl_aten=torch.sum,
    doc=_sum_doc,
)


def _prod_aten(
    inp: TensorLikeType,
    dims: Optional[DimsSequenceType],
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    if dims is not None:
        for d in sorted(dims, reverse=True):
            assert d >= 0
            inp = torch.prod(inp, d, dtype=dtype)
        return inp
    else:
        return torch.prod(inp, dims, dtype=dtype)


prod = _make_reduction_prim(
    name="prod",
    impl_aten=_prod_aten,
    doc=_prod_doc,
)

var = _make_var_reduction_prim(
    name="var",
    impl_aten=torch.var,
    doc=_var_doc,
)

amax = _make_reduction_prim(
    name="amax",
    impl_aten=torch.amax,
    doc=_amax_doc,
)

amin = _make_reduction_prim(
    name="amin",
    impl_aten=torch.amin,
    doc=_amin_doc,
)
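# Illustrative example (editor's addition): the reduction prims take an
# explicit list of dimensions, as the underlying ATen ops do.
#
#   >>> a = torch.arange(6.0).reshape(2, 3)
#   >>> torch.sum(a, dim=[1]).tolist()
#   [3.0, 12.0]
#   >>> torch.amax(a, dim=[0]).tolist()
#   [3.0, 4.0, 5.0]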
_iota_doc = """
Constructs a 1-D tensor t where ``t[i] == start + i * step``.
"""


# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _iota_meta(
    length: int,
    *,
    start: int,
    step: int,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    utils.check(
        utils.is_integer_dtype(dtype),
        lambda: "prims.iota only supports integer dtypes",
    )
    utils.check(step != 0, lambda: "step must be nonzero")
    return torch.empty(
        length,
        dtype=dtype,
        device=device,
        requires_grad=requires_grad,
    )


def _iota_aten(
    length: int,
    *,
    start: int,
    step: int,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    end = start + length * step
    return torch.arange(
        start, end, step, dtype=dtype, device=device, requires_grad=requires_grad
    )


iota = _make_prim(
    schema="iota(SymInt length, *, SymInt start, SymInt step, ScalarType dtype, Device device, bool requires_grad) -> Tensor",  # noqa: B950
    return_type=RETURN_TYPE.NEW,
    meta=_iota_meta,
    impl_aten=_iota_aten,
    doc=_iota_doc,
)
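# Illustrative example (editor's addition): iota(length, start, step) matches a
# length-bounded arange, which is exactly what the ATen implementation builds.
#
#   >>> length, start, step = 5, 2, 3
#   >>> torch.arange(start, start + length * step, step).tolist()
#   [2, 5, 8, 11, 14]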
# TODO: layout, pin_memory, memory_format
# TODO: model requires_grad on TensorMeta
def _empty_meta(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> TensorLikeType:
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _empty_aten(
    shape: ShapeType, *, dtype: torch.dtype, device: torch.device, requires_grad: bool
) -> Tensor:
    return torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)


_empty_doc = """
Creates a tensor with uninitialized values and the specified shape, dtype, and device.
"""

empty = _make_prim(
    schema="empty(SymInt[] shape, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_empty_meta,
    impl_aten=_empty_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_empty_doc,
)
def _empty_strided_meta(
    shape: ShapeType,
    strides: StrideType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


_empty_strided_doc = """
Creates a tensor with uninitialized values and the specified shape, strides, dtype, and device.
"""

# TODO: add layout, pin_memory
empty_strided = _make_prim(
    schema="empty_strided(SymInt[] shape, SymInt[] strides, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    return_type=RETURN_TYPE.NEW,
    meta=_empty_strided_meta,
    impl_aten=torch.empty_strided,
    doc=_empty_strided_doc,
)
def _full_meta(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _full_aten(
    shape: ShapeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full(
        shape, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
    )


_full_doc = """
Creates a tensor filled with the given fill value, and with the specified shape, dtype, and device.
"""

# TODO: add layout
full = _make_prim(
    schema="full(SymInt[] shape, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_meta,
    impl_aten=_full_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_doc,
)
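# Illustrative example (editor's addition): full materializes a constant
# tensor with contiguous strides.
#
#   >>> torch.full((2, 2), 7.0, dtype=torch.float32).tolist()
#   [[7.0, 7.0], [7.0, 7.0]]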
def _full_like_meta(
    a: TensorLikeType,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    strides = utils.compute_elementwise_output_strides(a)
    if a.numel() == 0:
        strides = a.stride()

    return TensorMeta(a, strides=strides, dtype=dtype, device=device)


def _full_like_aten(
    a: Tensor,
    fill_value: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    # Note that Mypy thinks torch.full can't accept a complex fill_value
    return torch.full_like(
        a, fill_value, dtype=dtype, device=device, requires_grad=requires_grad  # type: ignore[arg-type]
    )


_full_like_doc = """
Creates a tensor filled with the given fill value, and the same shape, dtype, and device as the
given tensor by default. The dtype and device settings can be overridden
by specifying them explicitly.
"""

full_like = _make_prim(
    schema="full_like(Tensor a, Scalar fill_value, *, ScalarType dtype, Device device, bool requires_grad) -> Tensor",
    meta=_full_like_meta,
    impl_aten=_full_like_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_full_like_doc,
)
def _scalar_tensor_meta(
    scalar: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> TensorLikeType:
    shape: ShapeType = []
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(scalar, shape=shape, strides=strides, dtype=dtype, device=device)


def _scalar_tensor_aten(
    scalar: NumberType,
    *,
    dtype: torch.dtype,
    device: torch.device,
) -> Tensor:
    if isinstance(scalar, complex) and (
        dtype is None or not utils.is_complex_dtype(dtype)
    ):
        raise TypeError("Complex scalar requires complex tensor dtype.")
    # Note that Mypy thinks torch.scalar can't accept a complex scalar
    return torch.scalar_tensor(scalar, dtype=dtype, device=device)  # type: ignore[arg-type]


_scalar_tensor_doc = """
Wraps a Number into a Tensor with the specified dtype and device.
"""

# TODO: add layout and pin_memory support
scalar_tensor = _make_prim(
    schema="scalar_tensor(Scalar s, *, ScalarType? dtype=None, Device? device=None) -> Tensor",
    meta=_scalar_tensor_meta,
    impl_aten=_scalar_tensor_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_scalar_tensor_doc,
)
#
# Linear algebra (linalg) prims
#


def _svd_meta(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[TensorLikeType, TensorLikeType, TensorLikeType]:
    utils.check_is_matrix(A, "linalg.svd")
    utils.check_fp_or_complex(A.dtype, "linalg.svd", allow_low_precision_dtypes=False)

    A_shape = A.shape
    batch = A_shape[:-2]
    m, n = A_shape[-2:]
    k = min(m, n)

    shape_U = batch + (m, m if full_matrices else k)
    strides_U = utils.make_contiguous_strides_for(shape_U, row_major=False)
    U = TensorMeta(shape=shape_U, strides=strides_U, dtype=A.dtype, device=A.device)

    shape_S = batch + (k,)
    strides_S = utils.make_contiguous_strides_for(shape_S)
    S = TensorMeta(
        shape=shape_S,
        strides=strides_S,
        dtype=utils.corresponding_real_dtype(A.dtype) if A.is_complex() else A.dtype,
        device=A.device,
    )

    shape_Vh = batch + (n if full_matrices else k, n)
    # The CPU backend returns V, but the cuSolver backend returns V^H
    # TODO: The MAGMA backend returns V, so this is wrong if used with the MAGMA backend
    is_cuda = A.device.type == "cuda"
    strides_Vh = utils.make_contiguous_strides_for(shape_Vh, row_major=is_cuda)
    Vh = TensorMeta(shape=shape_Vh, strides=strides_Vh, dtype=A.dtype, device=A.device)
    return U, S, Vh


def _svd_aten(
    A: TensorLikeType, *, full_matrices: bool
) -> Tuple[Tensor, Tensor, Tensor]:
    return torch.linalg.svd(A, full_matrices=full_matrices)


_svd_doc = """
Returns the SVD of a matrix or batch of matrices.

The `full_matrices` flag controls whether the full or reduced SVD decomposition is returned.
"""

svd = _make_prim(
    schema="svd(Tensor A, *, bool full_matrices) -> (Tensor U, Tensor S, Tensor Vh)",
    meta=_svd_meta,
    impl_aten=_svd_aten,
    return_type=(RETURN_TYPE.NEW, RETURN_TYPE.NEW, RETURN_TYPE.NEW),
    doc=_svd_doc,
)
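# Illustrative example (editor's addition): with full_matrices=False the
# factors have the reduced shapes computed by _svd_meta, and A is recovered as
# U @ diag(S) @ Vh.
#
#   >>> A = torch.randn(3, 2)
#   >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False)
#   >>> U.shape, S.shape, Vh.shape
#   (torch.Size([3, 2]), torch.Size([2]), torch.Size([2, 2]))
#   >>> torch.allclose(U @ torch.diag(S) @ Vh, A, atol=1e-6)
#   True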
#
# Randomness Prims
#

# TODO: add generator support
# NOTE: there is currently no way of acquiring the "default" torch generator
def _normal_meta(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> TensorLikeType:
    utils.check(
        std >= 0.0,
        lambda: f"expected non-negative standard deviation, but got std={std}",
    )
    utils.check(
        utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
        lambda: f"expected a floating-point or complex dtype, but got dtype={dtype}",
    )

    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _normal_aten(
    shape: ShapeType,
    *,
    mean: Union[float, complex],
    std: float,
    dtype: torch.dtype,
    device: torch.device,
    requires_grad: bool,
) -> Tensor:
    a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
    with torch.no_grad():
        # NOTE: normal_ is incorrectly annotated to expect mean to be a float
        a.normal_(mean, std)  # type: ignore[arg-type]
    return a
_normal_doc = """
Constructs a tensor filled with values drawn from a normal distribution with the specified mean
and standard deviation.

Only supports floating-point and complex types.
"""

normal = _make_prim(
    schema=(
        "normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_normal_meta,
    impl_aten=_normal_aten,
    doc=_normal_doc,
)
def _uniform_meta(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
) -> TensorLikeType:
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)


def _uniform_aten(
    shape: ShapeType,
    *,
    low: float,
    high: float,
    dtype: torch.dtype,
    device: torch.device,
) -> Tensor:
    a = torch.empty(shape, dtype=dtype, device=device)
    a.uniform_(low, high)
    return a


_uniform_doc = """
Constructs a tensor filled with values drawn uniformly from low to high.
"""

# TODO: we should more seriously review randomness modeling and prims
_uniform_helper = _make_prim(
    schema=(
        "uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor"
    ),
    return_type=RETURN_TYPE.NEW,
    meta=_uniform_meta,
    impl_aten=_uniform_aten,
    doc=_uniform_doc,
)
#
# FFT prims
#


def _fft_r2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = list(input.shape)
    if onesided:
        last_dim = dim[-1]
        shape[last_dim] = shape[last_dim] // 2 + 1

    dtype = utils.corresponding_complex_dtype(input.dtype)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)


def _fft_r2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    onesided: bool,
) -> TensorLikeType:
    normalization = 0  # No normalization
    return torch._fft_r2c(input, dim, normalization, onesided)
_fft_r2c_doc = """
Performs a real-to-complex Fast Fourier Transform.
"""

fft_r2c = _make_prim(
    schema="fft_r2c(Tensor self, *, int[] dim, bool onesided) -> Tensor",
    meta=_fft_r2c_meta,
    impl_aten=_fft_r2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_r2c_doc,
)
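# Illustrative example (editor's addition): the one-sided real-to-complex
# transform keeps n // 2 + 1 frequencies along the transformed dimension, as
# the public torch.fft.rfft does.
#
#   >>> torch.fft.rfft(torch.arange(4.0)).shape
#   torch.Size([3])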
def _fft_c2c_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = input.shape
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(
        shape=shape, strides=strides, dtype=input.dtype, device=input.device
    )


def _fft_c2c_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    forward: bool,
) -> TensorLikeType:
    normalization = 0  # No normalization
    return torch._fft_c2c(input, dim, normalization, forward)
_fft_c2c_doc = """
Performs either a complex-to-complex Fast Fourier Transform or its inverse.
"""

fft_c2c = _make_prim(
    schema="fft_c2c(Tensor self, *, int[] dim, bool forward) -> Tensor",
    meta=_fft_c2c_meta,
    impl_aten=_fft_c2c_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2c_doc,
)
def _fft_c2r_meta(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    dim = utils.canonicalize_dims(input.ndim, dim)
    utils.validate_no_repeating_dims(dim)

    shape = list(input.shape)
    shape[dim[-1]] = last_dim_size
    dtype = utils.corresponding_real_dtype(input.dtype)
    strides = utils.make_contiguous_strides_for(shape)
    return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=input.device)


def _fft_c2r_aten(
    input: TensorLike,
    *,
    dim: DimsSequenceType,
    last_dim_size: int,
) -> TensorLikeType:
    normalization = 0  # No normalization
    return torch._fft_c2r(input, dim, normalization, last_dim_size)
_fft_c2r_doc = """
Performs a complex-to-real inverse Fast Fourier Transform.
"""

fft_c2r = _make_prim(
    schema="fft_c2r(Tensor self, *, int[] dim, SymInt last_dim_size) -> Tensor",
    meta=_fft_c2r_meta,
    impl_aten=_fft_c2r_aten,
    return_type=RETURN_TYPE.NEW,
    doc=_fft_c2r_doc,
)
register_nvprims()