- from __future__ import annotations
- from typing import Any, Union, Sequence, Optional, Tuple, List, Callable, Type, overload, cast
- from enum import Enum
- from functools import reduce, cmp_to_key
- import operator
- import weakref
- import torch
- from torch import sym_float, sym_int, sym_max
- try:
- from nvfuser._C import DataType # type: ignore[import]
- _torch_dtype_to_nvfuser_dtype_map = {
- torch.cdouble: DataType.ComplexDouble,
- torch.cfloat: DataType.ComplexFloat,
- torch.double: DataType.Double,
- torch.float: DataType.Float,
- torch.half: DataType.Half,
- torch.bfloat16: DataType.BFloat16,
- torch.long: DataType.Int,
- torch.int: DataType.Int32,
- torch.uint8: DataType.Int32,
- torch.bool: DataType.Bool,
- # Python scalars
- complex: DataType.ComplexDouble,
- float: DataType.Double,
- int: DataType.Int,
- bool: DataType.Bool,
- }
- except ImportError:
- _torch_dtype_to_nvfuser_dtype_map = {}
- def getnvFuserDtype(dtype: Union[torch.dtype, NumberTypeType]):
- """
- Translates from torch.dtype to nvFuser's DataType enum
- """
- return _torch_dtype_to_nvfuser_dtype_map[dtype]
- ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
- StrideType = Union[List[int], Tuple[int, ...]]
- DimsType = Union[int, List[int], Tuple[int, ...]]
- DimsSequenceType = Union[List[int], Tuple[int, ...]]
- # TODO: Type[torch.SymInt], Type[torch.SymFloat]
- NumberTypeType = Union[Type[bool], Type[int], Type[float], Type[complex]]
- # TODO: This needs a lot more type annotations
- # NumberType = Union[bool, int, float, complex, torch.SymInt, torch.SymFloat]
- NumberType = Union[bool, int, float, complex]
- Number = (bool, int, float, complex, torch.SymInt, torch.SymFloat)
- # I don't call it Integral because numbers.Integral includes bool, but IntLike
- # does not
- Dim = int
- IntLike = (int, torch.SymInt)
- FloatLike = (float, torch.SymFloat)
- IntWithoutSymInt = int
- FloatWithoutSymFloat = float
- DeviceLikeType = Union[str, torch.device]
- Tensor = torch.Tensor
- torch_function_passthrough = {
- torch.Tensor.dim,
- torch.Tensor.ndim.__get__, # type: ignore[attr-defined]
- torch.Tensor.numel,
- torch.Tensor.size,
- torch.Tensor.storage_offset,
- torch.Tensor.stride,
- torch.Tensor.dtype.__get__, # type: ignore[attr-defined]
- torch.Tensor.is_sparse.__get__, # type: ignore[attr-defined]
- torch.Tensor.shape.__get__, # type: ignore[attr-defined]
- torch.Tensor.device.__get__, # type: ignore[attr-defined]
- torch.Tensor.requires_grad.__get__, # type: ignore[attr-defined]
- torch.Tensor.layout.__get__, # type: ignore[attr-defined]
- # For TorchRefsMode only
- torch.Tensor.__format__,
- torch.Tensor.__repr__,
- }
- TensorLikeType = torch.Tensor
- TensorLike = torch.Tensor
- TensorSequenceType = Union[List[TensorLikeType], Tuple[TensorLikeType, ...]]
- TensorOrNumberLikeType = Union[TensorLikeType, NumberType]
- def same_shape(a: ShapeType, b: ShapeType) -> bool:
- if len(a) != len(b):
- return False
- for x, y in zip(a, b):
- if x != y:
- return False
- return True
- # TODO: look at using torch.testing.assert_close instead with an option
- # to just compare metadata
- def compare_tensor_meta(a: TensorLikeType, b: TensorLikeType, check_strides=False):
- """
- Checks that two tensor likes have the same shape,
- dtype and device.
- In the future this will validate additional metadata, like
- strides.
- """
- assert isinstance(a, TensorLike)
- assert isinstance(b, TensorLike)
- if not same_shape(a.shape, b.shape):
- msg = "Shapes {0} and {1} are not equal!".format(a.shape, b.shape)
- raise AssertionError(msg)
- if a.dtype != b.dtype:
- msg = "Dtypes {0} and {1} are not equal!".format(a.dtype, b.dtype)
- raise AssertionError(msg)
- if a.device != b.device:
- # Handles special cuda:0 vs cuda case
- # TODO: we should review why this happens and see about fixing it
- if (str(a.device) == "cuda:0" or str(a.device) == "cuda") and (
- str(b.device) == "cuda:0" or str(b.device) == "cuda"
- ):
- pass
- else:
- msg = "Devices {0} and {1} are not equal!".format(a.device, b.device)
- raise AssertionError(msg)
- # Stride checking is currently disabled, see https://github.com/pytorch/pytorch/issues/78050
- if check_strides:
- same_strides, idx = check_significant_strides(a, b)
- if not same_strides:
- msg = (
- "Stride mismatch! Strides are {0} and {1} (mismatched at {2})!".format(
- a.stride(), b.stride(), idx
- )
- )
- raise RuntimeError(msg)
- if a.storage_offset() != b.storage_offset():
- msg = (
- "Storage offset mismatch! Storage offsets are {0} and {1}!".format(
- a.storage_offset(), b.storage_offset()
- )
- )
- raise RuntimeError(msg)
- def _check_strides_helper(
- a: TensorLikeType, b: TensorLikeType, *, only_cuda=True, significant_only=True
- ) -> Tuple[bool, Optional[int]]:
- # NOTE: only on CUDA because CPU elementwise strides are incorrect in PyTorch
- # See https://github.com/pytorch/pytorch/issues/77553
- # Only compares strides that are "meaningful" -- strides for dimensions with length > 1
- # and for tensors with more than one element
- if (not only_cuda or a.device.type == "cuda" or b.device.type == "cuda") and a.numel() > 0:
- for idx in range(a.ndim):
- check = not significant_only or a.shape[idx] > 1
- if a.stride()[idx] != b.stride()[idx] and check:
- return False, idx
- return True, None
- def check_significant_strides(
- a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
- ) -> Tuple[bool, Optional[int]]:
- return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=True)
- def check_all_strides(
- a: TensorLikeType, b: TensorLikeType, *, only_cuda=True
- ) -> Tuple[bool, Optional[int]]:
- return _check_strides_helper(a, b, only_cuda=only_cuda, significant_only=False)
- # This function is equivalent to compute_contiguous() from TensorImpl.cpp
- def is_contiguous(a: TensorLikeType) -> bool:
- """
- Tests whether a tensor is contiguous or not.
- Tensors are contiguous when they have no elements,
- one element, or when they have "nested" strides.
- """
- if a.numel() < 2:
- return True
- expected_stride = 1
- for x, y in reversed(tuple(zip(a.shape, a.stride()))):
- # Skips checking strides when a dimension has length 1
- if x == 1:
- continue
- if y != expected_stride:
- return False
- expected_stride = expected_stride * x
- return True
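- # Illustrative doctest-style sketch (assuming standard dense torch tensors):
- # >>> t = torch.empty(2, 3)    # strides (3, 1)
- # >>> is_contiguous(t)
- # True
- # >>> is_contiguous(t.t())     # shape (3, 2), strides (1, 3)
- # False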
- # This function is equivalent to compute_channels_last_contiguous_2d() in TensorImpl.cpp
- def is_channels_last_contiguous_2d(a: Tensor) -> bool:
- # NHWC or not channels last 2D contiguous
- if a.ndim != 4:
- return False
- expected_stride = 1
- for idx in (1, 3, 2, 0):
- length = a.shape[idx]
- if length == 1:
- continue
- stride = a.stride()[idx]
- if stride != expected_stride:
- return False
- expected_stride *= length
- return True
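- # Illustrative sketch: a 4D tensor allocated with torch.channels_last, e.g.
- # strides (60, 1, 15, 3) for shape (2, 3, 4, 5), passes this check, while the
- # default contiguous strides (60, 20, 5, 1) do not:
- # >>> x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
- # >>> is_channels_last_contiguous_2d(x)
- # True
- # >>> is_channels_last_contiguous_2d(torch.empty(2, 3, 4, 5))
- # False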
- def is_channels_last_contiguous_3d(a: Tensor) -> bool:
- # NDHWC or not channels last 3D contiguous
- if a.ndim != 5:
- return False
- expected_stride = 1
- for idx in (1, 4, 3, 2, 0):
- length = a.shape[idx]
- if length == 1:
- continue
- stride = a.stride()[idx]
- if stride != expected_stride:
- return False
- expected_stride *= length
- return True
- _memory_formats = {
- torch.contiguous_format,
- torch.preserve_format,
- torch.channels_last,
- torch.channels_last_3d,
- }
- def validate_memory_format(memory_format: torch.memory_format):
- check(
- memory_format in _memory_formats,
- lambda: f"Received unknown memory format {memory_format}!",
- )
- def is_contiguous_for_memory_format( # type: ignore[return]
- a: Tensor, *, memory_format: torch.memory_format
- ) -> bool:
- validate_memory_format(memory_format)
- if memory_format == torch.contiguous_format:
- return is_contiguous(a)
- if memory_format == torch.channels_last:
- return is_channels_last_contiguous_2d(a)
- if memory_format == torch.channels_last_3d:
- return is_channels_last_contiguous_3d(a)
- check(
- False,
- lambda: f"is_contiguous received unsupported memory format {memory_format}",
- )
- # NOTE: it is currently unclear whether tensors with no elements should count as channels-last contiguous
- def is_channels_last_contiguous(a: Tensor) -> bool:
- """
- True when a tensor is channels-last contiguous.
- This requires that:
- - the tensor is conceptually either 4 (NHWC) or 5 (NDHWC) dimensions
- - if we name the tensor's dimensions NCHW or NCDHW, then the strides are such that the
- stride of the 'C' dimension (Cs) is 1 and the strides corresponding to
- each dimension (Xs) can be ordered Cs <= Ws <= Hs <= (Ds) <= Ns and are
- "nested" -- so Ws = Cs * Cl, where Cl is the length of the 'C' dimension,
- for example.
- """
- return is_channels_last_contiguous_2d(a) or is_channels_last_contiguous_3d(a)
- def is_non_overlapping_and_dense(a: Tensor) -> bool:
- """
- True when a tensor is non-overlapping and dense.
- A tensor is non-overlapping and dense when there exists a permutation of
- its dimensions that is contiguous.
- """
- if a.is_sparse:
- return False
- # Short-circuits if the tensor is already contiguous or channels-last contiguous
- if is_contiguous(a) or is_channels_last_contiguous(a):
- return True
- # The following is equivalent to compute_non_overlapping_and_dense in TensorImpl.cpp
- # Short-circuits for tensors of rank one, which are
- # non-overlapping and "dense" if their stride is one
- if a.ndim == 1:
- return a.stride()[0] == 1
- # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous
- # Sorts (length, stride) pairs by stride
- lengths_and_strides = sorted(
- zip(a.shape, a.stride()), key=operator.itemgetter(1)
- )
- expected_stride = 1
- for length, stride in lengths_and_strides:
- if length == 1:
- continue
- if stride != expected_stride:
- return False
- expected_stride *= length
- return True
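- # Illustrative sketch: any permutation of a contiguous tensor is
- # non-overlapping and dense, while a strided slice generally is not:
- # >>> is_non_overlapping_and_dense(torch.empty(2, 3).t())
- # True
- # >>> is_non_overlapping_and_dense(torch.empty(5)[::2])    # stride (2,)
- # False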
- # NOTE: Based on the implementation in TensorIterator.cpp, but note that
- # the note [Computing output strides] there is incorrect: it claims that
- # strides will be preserved even when they are not
- # "non overlapping and dense", whereas the output of elementwise
- # operations is always given non overlapping and dense strides.
- # This function is also INCORRECT because it does not model TensorIterator's
- # short-circuit, which can cause different strides.
- def compute_elementwise_output_strides(*tensors) -> Tuple[int, ...]:
- """
- Computes the output strides for elementwise operations.
- """
- if len(tensors) == 0:
- msg = "Can't compute elementwise output strides for zero tensors!"
- raise ValueError(msg)
- check_same_shape(*tensors, allow_cpu_scalar_tensors=True)
- # Filters the tensors to actual tensors
- tensors = tuple(
- a for a in tensors if isinstance(a, TensorLike) and not is_cpu_scalar_tensor(a)
- )
- # Short-circuits for CPU scalar case
- if len(tensors) == 0:
- return ()
- # Short-circuits for shapes with zero or one dimensions
- # TODO: are these necessary?
- ndim = tensors[0].ndim
- if ndim == 0:
- return ()
- if ndim == 1:
- return (1,)
- shape = tensors[0].shape
- def should_swap(idx_a, idx_b):
- for tensor in tensors:
- stride_a = tensor.stride()[idx_a]
- stride_b = tensor.stride()[idx_b]
- if stride_a == 0 or stride_b == 0:
- continue
- if stride_a < stride_b:
- return -1
- if stride_a > stride_b:
- return 1
- # stride_a == stride_b
- if shape[idx_a] > shape[idx_b]:
- return 1
- # Note: this case is hit if all strides are zero,
- # or all strides are equal and all dimensions have the same length
- return 0
- perm = list(reversed(range(ndim)))
- # insertion sort with support for ambiguous comparisons
- for i in range(1, ndim):
- dim1 = i
- for dim0 in reversed(range(i)):
- comparison = should_swap(perm[dim0], perm[dim1])
- if comparison > 0:
- perm[dim0], perm[dim1] = perm[dim1], perm[dim0]
- dim1 = dim0
- elif comparison < 0:
- break
- permuted_shape = [-1] * ndim
- for idx, x in enumerate(reversed(perm)):
- permuted_shape[idx] = shape[x]
- new_strides = make_contiguous_strides_for(permuted_shape)
- permuted_strides = [-1] * ndim
- for idx, x in enumerate(reversed(perm)):
- permuted_strides[x] = new_strides[idx]
- return tuple(permuted_strides)
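- # Illustrative sketch: the output strides follow the stride order preferred by
- # the inputs, so a single channels-last input keeps its layout:
- # >>> x = torch.empty(2, 3, 4, 5, memory_format=torch.channels_last)
- # >>> compute_elementwise_output_strides(x)
- # (60, 1, 15, 3)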
- #
- # Common helper functions
- #
- def validate_dim_length(length: int):
- """
- Validates that an object represents a valid
- dimension length.
- """
- assert length >= 0
- def validate_shape(shape: ShapeType):
- """
- Validates that a sequence represents a valid shape.
- """
- assert isinstance(shape, Sequence)
- for l in shape:
- validate_dim_length(l)
- def validate_strides(strides: StrideType):
- """
- Verifies the object specifies valid strides.
- """
- assert isinstance(strides, Sequence)
- for stride in strides:
- assert stride >= 0
- def validate_idx(rank: int, idx: int):
- """
- Validates that idx is a valid index for the given rank.
- Assumes the index is already canonicalized.
- """
- assert isinstance(idx, Dim)
- assert isinstance(rank, Dim)
- assert (idx >= 0 and idx < rank) or idx == 0
- def validate_dimension_indices(rank: int, indices: DimsSequenceType):
- for idx in indices:
- validate_idx(rank, idx)
- def validate_exclusive_idx(rank: int, ex_idx: int):
- """
- Validates that ex_idx is a valid exclusive index
- for the given rank.
- """
- assert isinstance(ex_idx, Dim)
- assert isinstance(rank, Dim)
- assert ex_idx > 0 and ex_idx <= rank
- # "Wraps" a dim (up to one time) for the given rank, allowing dims to be
- # specified using negative indices. If `wrap_scalar` is true then scalar
- # tensors of rank 0 will allow dimensions in the range [-1, 0]. Otherwise,
- # idx should be in the range [-rank, rank-1].
- def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
- if rank < 0:
- msg = f"Rank cannot be negative but got {rank}"
- raise IndexError(msg)
- if rank == 0:
- if not wrap_scalar:
- msg = f"Dimension specified as {idx} but tensor has no dimensions"
- raise IndexError(msg)
- rank = 1
- if idx >= 0 and idx < rank:
- return idx
- if idx < 0:
- _idx = idx + rank
- else:
- _idx = idx
- if _idx < 0 or _idx >= rank:
- # Same error message as in aten/src/ATen/WrapDimUtils.h:49
- msg = "Dimension out of range (expected to be in range of [{0}, {1}], but got {2})".format(
- -rank, rank - 1, idx
- )
- raise IndexError(msg)
- return _idx
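- # Illustrative sketch of the wrapping behavior:
- # >>> canonicalize_dim(4, -1)
- # 3
- # >>> canonicalize_dim(4, 2)
- # 2
- # >>> canonicalize_dim(0, -1)    # rank-0 tensor with wrap_scalar=True
- # 0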
- # Takes a dimension or sequence of dimensions and "wraps" them,
- # mapping negative offsets to positive ones
- @overload
- def canonicalize_dims(rank: int, indices: Sequence[int], wrap_scalar: bool = True) -> Tuple[int, ...]:
- pass
- @overload
- def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
- pass
- def canonicalize_dims(rank, indices, wrap_scalar=True):
- if isinstance(indices, Dim):
- return canonicalize_dim(rank, indices, wrap_scalar)
- return tuple(canonicalize_dim(rank, x, wrap_scalar) for x in indices)
- def is_valid_permutation(rank: int, perm: DimsSequenceType) -> bool:
- """
- Returns True if perm is a valid permutation of length rank and False otherwise.
- """
- if not isinstance(perm, Sequence):
- return False
- if not (tuple(sorted(perm)) == tuple(range(0, rank))):
- return False
- return True
- def is_same_shape(a: Sequence, b: Sequence) -> bool:
- """
- Compares two shapes a and b, returning True if they are the same
- (their ranks and corresponding lengths match) and False otherwise.
- """
- return tuple(a) == tuple(b)
- def is_cpu_scalar_tensor(a: Any) -> bool:
- return isinstance(a, TensorLike) and a.ndim == 0 and a.device.type == "cpu"
- def check_same_device(*args, allow_cpu_scalar_tensors):
- """
- Checks that all Tensors in args have the same device.
- Raises a RuntimeError when:
- - args contains an object whose type is not Tensor or Number
- - two Tensor objects in args have different devices, unless one is a CPU scalar tensor and allow_cpu_scalar_tensors is True
- """
- # Short-circuits if all (one or fewer) arguments are trivially on the same device
- if len(args) <= 1:
- return
- # Note: cannot initialize device to the first arg's device (it may not have one)
- device = None
- for arg in args:
- if isinstance(arg, Number):
- continue
- elif isinstance(arg, TensorLike):
- if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
- continue
- if device is None:
- device = arg.device
- if device != arg.device:
- msg = (
- "Tensor on device "
- + str(arg.device)
- + " is not on the expected device "
- + str(device)
- + "!"
- )
- raise RuntimeError(msg)
- else:
- msg = (
- "Unexpected type when checking for same device, " + str(type(arg)) + "!"
- )
- raise RuntimeError(msg)
- def canonicalize_device(device: DeviceLikeType) -> torch.device:
- if isinstance(device, torch.device):
- return device
- assert isinstance(device, str)
- return torch.device(device)
- # Raises if any of the following are true:
- # - an argument is neither a Number nor a Tensor
- # - the shapes of any two tensors are distinct
- def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
- """
- Checks that all Tensors in args have the same shape.
- Raises a RuntimeError when:
- - args contains an object whose type is not Tensor or Number
- - two Tensor objects in args have different shapes
- """
- shape = None
- for arg in args:
- if isinstance(arg, Number):
- continue
- elif isinstance(arg, TensorLike):
- if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
- continue
- if shape is None:
- shape = arg.shape
- if not is_same_shape(shape, arg.shape):
- msg = "Shape {0} is not the expected shape {1}!".format(
- arg.shape, shape
- )
- raise RuntimeError(msg)
- else:
- msg = (
- "Unexpected type when checking for same shape, " + str(type(arg)) + "!"
- )
- raise RuntimeError(msg)
- # Acquires a common shape, if it exists, from one or more tensor arguments,
- # filtering number arguments
- def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
- shape = None
- scalar_shape = None
- for arg in args:
- if isinstance(arg, Number):
- continue
- elif isinstance(arg, TensorLike):
- if allow_cpu_scalar_tensors and is_cpu_scalar_tensor(arg):
- scalar_shape = arg.shape
- continue
- if shape is None:
- shape = arg.shape
- if not is_same_shape(shape, arg.shape):
- return None
- else:
- return None
- return shape if shape is not None else scalar_shape
- # Extracts dimensions that might be passed either as a list/tuple or as varargs.
- # A typical case is Tensor.permute.
- def extract_dims_from_varargs(dims: Union[DimsSequenceType, Tuple[DimsSequenceType, ...]]) -> DimsSequenceType:
- if dims and isinstance(dims[0], Sequence):
- assert len(dims) == 1
- dims = cast(Tuple[DimsSequenceType], dims)
- return dims[0]
- else:
- return cast(DimsSequenceType, dims)
- def extract_shape_from_varargs(
- shape: Union[ShapeType, Tuple[ShapeType]],
- validate=True,
- ) -> Tuple[int, ...]:
- """
- Returns a shape from varargs.
- In PyTorch, operations that accept shapes often accept them as varargs, like
- foo(*shape). However a user can pass the shape as a sequence of integers,
- like this:
- foo(1, 2, 3)
- or as a single sequence of integers, like this:
- foo((1, 2, 3))
- In the first case shape will be a tuple of integers, and in the second case it's a tuple
- containing a tuple of integers. This validates those inputs and canonicalizes them
- to a tuple of integers.
- """
- # Handles tuple unwrapping
- if len(shape) == 1 and isinstance(shape[0], Sequence):
- shape = shape[0]
- if validate:
- validate_shape(shape) # type: ignore[arg-type]
- return shape # type: ignore[return-value]
- def infer_size(shape: ShapeType, numel: int) -> Tuple[int, ...]:
- """
- Infers the size of a dim with size -1, if it exists.
- Also checks that the new shape is compatible with the number of elements.
- """
- dim = None
- newsize = 1
- for i, d in enumerate(shape):
- if d == -1:
- check(dim is None, lambda: "only one dimension can be inferred")
- dim = i
- elif d >= 0:
- newsize *= d
- else:
- check(False, lambda: f"invalid shape dimension {d}")
- check(
- numel == newsize or (dim is not None and newsize > 0 and numel % newsize == 0),
- lambda: f"shape '{list(shape)}' is invalid for input of size {numel}",
- )
- if dim is not None:
- # Convert to list to produce a compatible error message with core
- # PyTorch, which prints sequences in square brackets.
- shape = list(shape)
- check(
- newsize != 0,
- lambda: (f"cannot reshape tensor of 0 elements into shape {shape} because the "
- f"unspecified dimension size -1 can be any value and is ambiguous"),
- )
- shape[dim] = numel // newsize
- return tuple(shape)
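- # Illustrative sketch: the -1 dimension is inferred from the element count:
- # >>> infer_size((2, -1, 4), 24)
- # (2, 3, 4)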
- _integer_dtypes = (torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
- _low_precision_dtypes = (torch.float16, torch.bfloat16, torch.complex32)
- _float_dtypes = (torch.float16, torch.bfloat16, torch.float32, torch.float64)
- _complex_dtypes = (torch.complex32, torch.complex64, torch.complex128)
- def is_boolean_dtype(dtype: torch.dtype) -> bool:
- assert isinstance(dtype, torch.dtype)
- return dtype is torch.bool
- def is_integer_dtype(dtype: torch.dtype) -> bool:
- assert isinstance(dtype, torch.dtype)
- return dtype in _integer_dtypes
- def is_low_precision_dtype(dtype: torch.dtype) -> bool:
- assert isinstance(dtype, torch.dtype)
- return dtype in _low_precision_dtypes
- def is_float_dtype(dtype: torch.dtype) -> bool:
- assert isinstance(dtype, torch.dtype)
- return dtype in _float_dtypes
- def is_complex_dtype(dtype: torch.dtype) -> bool:
- assert isinstance(dtype, torch.dtype)
- return dtype in _complex_dtypes
- def is_grad_dtype(dtype: torch.dtype) -> bool:
- """
- Checks if the dtype can require a gradient.
- """
- return is_float_dtype(dtype) or is_complex_dtype(dtype)
- _complex_to_real_dtype_map = {
- torch.complex128: torch.float64,
- torch.complex64: torch.float32,
- torch.complex32: torch.float16,
- }
- _real_to_complex_dtype_map = {
- torch.float16: torch.complex32,
- torch.bfloat16: torch.complex64,
- torch.float32: torch.complex64,
- torch.float64: torch.complex128,
- }
- def corresponding_real_dtype(dtype: torch.dtype) -> torch.dtype:
- return _complex_to_real_dtype_map[dtype]
- def corresponding_complex_dtype(dtype: torch.dtype) -> torch.dtype:
- return _real_to_complex_dtype_map[dtype]
- def dtype_to_type(dtype: torch.dtype) -> type:
- """
- Computes the corresponding Python type (AKA "type kind") for the
- given dtype.
- """
- assert isinstance(dtype, torch.dtype)
- if dtype is torch.bool:
- return bool
- if dtype in _integer_dtypes:
- return int
- if dtype in _float_dtypes:
- return float
- if dtype in _complex_dtypes:
- return complex
- raise ValueError("Invalid dtype!")
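- # Illustrative sketch of the dtype-to-type-kind mapping:
- # >>> dtype_to_type(torch.int32)
- # <class 'int'>
- # >>> dtype_to_type(torch.complex64)
- # <class 'complex'>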
- def dtype_to_type_ctor(dtype: torch.dtype) -> Callable[[NumberType], NumberType]:
- """
- Computes the corresponding Python type constructor for the
- given dtype.
- """
- assert isinstance(dtype, torch.dtype)
- if dtype is torch.bool:
- return lambda x: bool(x)
- if dtype in _integer_dtypes:
- return sym_int
- if dtype in _float_dtypes:
- return sym_float
- if dtype in _complex_dtypes:
- # TODO: type error here is real, replace with sym_complex
- return lambda x: complex(x) # type: ignore[arg-type]
- raise ValueError("Invalid dtype!")
- def type_to_dtype(typ: type) -> torch.dtype:
- """
- Computes the corresponding dtype for a Number type.
- """
- assert isinstance(typ, type)
- if typ is bool:
- return torch.bool
- if typ in [int, torch.SymInt]:
- return torch.long
- if typ in [float, torch.SymFloat]:
- return torch.get_default_dtype()
- # TODO: sym_complex_float?
- if typ is complex:
- return corresponding_complex_dtype(torch.get_default_dtype())
- raise ValueError("Invalid type!")
- def get_dtype(x: Union[torch.Tensor, NumberType]):
- if isinstance(x, torch.Tensor):
- return x.dtype
- else:
- return type_to_dtype(type(x))
- _ordered_types = (bool, int, float, complex)
- def check_fp_or_complex(
- dtype: torch.dtype, fn_name: str, allow_low_precision_dtypes: bool = True
- ):
- """
- Checks whether the input is floating point or complex.
- If allow_low_precision_dtypes is True, it allows having float16, bfloat16, and complex32
- """
- check(
- is_float_dtype(dtype) or is_complex_dtype(dtype),
- lambda: f"{fn_name}: Expected a floating point or complex tensor as input. Got {dtype}",
- )
- check(
- allow_low_precision_dtypes or not is_low_precision_dtype(dtype),
- lambda: f"{fn_name}: Half precision dtypes not supported. Got {dtype}",
- )
- def check_is_matrix(A: TensorLikeType, f_name: str, arg_name: str = "A"):
- check(
- len(A.shape) >= 2,
- lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
- )
- def get_higher_type(a: type, b: type) -> type:
- """
- Returns the higher of the two given Number types.
- The types are ordered bool -> int -> float -> complex.
- """
- # Type checking
- assert a in _ordered_types
- assert b in _ordered_types
- if a is b:
- return a
- for typ in _ordered_types:
- if a is typ:
- return b
- if b is typ:
- return a
- raise ValueError("Unknown Python scalar type!")
- # Returns the higher of two torch datatypes a and b or, if the two
- # are not ordered relative to each other, the next
- # higher datatype
- def get_higher_dtype(
- a: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
- b: Optional[Union[torch.dtype, TensorLikeType, NumberType]],
- ) -> Optional[torch.dtype]:
- """
- Computes the "lowest" datatype that is weakly
- "higher" than both a and b.
- """
- # Type checking
- assert a is None or isinstance(a, (torch.dtype, TensorLike, Number))
- assert b is None or isinstance(b, (torch.dtype, TensorLike, Number))
- def _extract_dtype(
- x: Optional[Union[torch.dtype, TensorLikeType, NumberType]]
- ) -> Optional[torch.dtype]:
- if x is None:
- return None
- if isinstance(x, torch.dtype):
- return x
- if isinstance(x, TensorLike):
- return x.dtype
- if isinstance(x, Number):
- return type_to_dtype(type(x))
- raise RuntimeError("Unexpected type given to _extract_dtype!")
- a, b = _extract_dtype(a), _extract_dtype(b)
- if a is b:
- return a
- if a is None:
- return b
- if b is None:
- return a
- ordered_datatypes = (
- (torch.bool,),
- (torch.uint8, torch.int8),
- (torch.int16,),
- (torch.int32,),
- (torch.int64,),
- (torch.float16, torch.bfloat16),
- (torch.float32,),
- (torch.float64,),
- (torch.complex32,),
- (torch.complex64,),
- (torch.complex128,),
- )
- for idx, dtypes in enumerate(ordered_datatypes):
- if a in dtypes and b in dtypes:
- return ordered_datatypes[idx + 1][0]
- if a in dtypes:
- return b
- if b in dtypes:
- return a
- raise RuntimeError("Unexpected termination!")
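- # Illustrative sketch: ordered dtypes return the higher of the two, while an
- # unordered pair such as float16 and bfloat16 returns the next higher dtype:
- # >>> get_higher_dtype(torch.int64, torch.float32)
- # torch.float32
- # >>> get_higher_dtype(torch.float16, torch.bfloat16)
- # torch.float32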
- def check_pin_memory(pin_memory: bool):
- check(not pin_memory, lambda: "PrimTorch does not support pinned memory", NotImplementedError)
- def check_layout(layout: torch.layout):
- check(layout == torch.strided, lambda: f"PrimTorch doesn't support layout={layout}", NotImplementedError)
- # TODO: maybe unify with can_cast_to?
- def is_weakly_lesser_type(a: type, b: type) -> bool:
- """
- Compares two types, a and b, returning True if a is weakly "less" than b.
- The comparison is determined by the following type ordering: bool, int, float, complex.
- """
- ordered_types = (
- bool,
- int,
- float,
- complex,
- )
- assert a in ordered_types
- assert b in ordered_types
- for typ in ordered_types:
- if a == typ:
- return True
- if b == typ:
- return False
- raise RuntimeError("Unexpected termination!")
- def can_safe_cast_to(*, cast_to: torch.dtype, cast_from: torch.dtype) -> bool:
- for fn in (is_complex_dtype, is_float_dtype, is_integer_dtype, is_boolean_dtype):
- if fn(cast_to):
- return True
- if fn(cast_from):
- return False
- raise ValueError("Received unknown dtypes {0}, {1}!".format(cast_to, cast_from))
- def check_same_dtype(*args):
- """
- Checks that all Tensors in args have the same dtype and that all Numbers have the
- same corresponding Python type.
- Raises a RuntimeError when:
- - args contains an object whose type is not Tensor or Number
- - two Tensor objects in args have different dtypes
- - two Number objects in args have different types
- - there are Tensors and Numbers in args, and one of those Tensors' corresponding
- Python types is different from the type of one of those Numbers
- """
- full_dtype = None
- scalar_type = None
- for arg in args:
- if isinstance(arg, Number):
- # Scalar type checking is disabled (and may be removed in the future)
- continue
- # if scalar_type is None:
- # scalar_type = type(arg)
- # if scalar_type is not type(arg):
- # msg = (
- # "Scalar of type "
- # + str(type(arg))
- # + " is not the expected type of "
- # + str(scalar_type)
- # + "!"
- # )
- # raise RuntimeError(msg)
- elif isinstance(arg, TensorLike):
- if full_dtype is None:
- full_dtype = arg.dtype
- if scalar_type is None:
- scalar_type = dtype_to_type(arg.dtype)
- if full_dtype is not arg.dtype:
- msg = (
- "Tensor with dtype "
- + str(arg.dtype)
- + " is not the expected dtype of "
- + str(full_dtype)
- + "!"
- )
- raise RuntimeError(msg)
- arg_type = dtype_to_type(arg.dtype)
- if arg_type is not scalar_type:
- msg = (
- "Tensor with corresponding Python type "
- + str(arg_type)
- + " is not the expected type of "
- + str(scalar_type)
- + "!"
- )
- raise RuntimeError(msg)
- else:
- msg = (
- "Unexpected type when checking for same dtype, " + str(type(arg)) + "!"
- )
- raise RuntimeError(msg)
- # Maps datatypes to their computation types for elementwise operations
- _computation_dtype_map = {
- torch.bfloat16: torch.float32,
- torch.float16: torch.float32,
- torch.complex32: torch.complex64,
- }
- def get_computation_dtype(dtype: torch.dtype) -> torch.dtype:
- return _computation_dtype_map.get(dtype, dtype)
- _cpu_acc_type_map = {
- torch.bfloat16: torch.float64,
- torch.float16: torch.float64,
- torch.float32: torch.float64,
- torch.complex32: torch.complex128,
- torch.complex64: torch.complex128,
- }
- def get_acc_type(dtype: torch.dtype, device: torch.device) -> torch.dtype:
- # Equivalent to at::toAccumulateType, prefer computation_dtype where possible
- if device.type == "cpu":
- return _cpu_acc_type_map.get(dtype, dtype)
- else:
- return get_computation_dtype(dtype)
- class ELEMENTWISE_TYPE_PROMOTION_KIND(Enum):
- DEFAULT = (0,)
- NO_OPMATH = (1,)
- INT_TO_FLOAT = (2,)
- ALWAYS_BOOL = (3,)
- COMPLEX_TO_FLOAT = (4,)
- BOOL_TO_LONG = (5,)
- class REDUCTION_OUTPUT_TYPE_KIND(Enum):
- SAME = (0,)
- COMPLEX_TO_FLOAT = (1,) # for complex types, outputs the corresponding real type
- KEEP_PROMOTED_TYPE = (2,) # keep output in opmath type, needed for mean
- ALWAYS_BOOL = (3,)
- # Describes the return type of the primitive:
- #
- # - NEW, a new tensor is created
- # - VIEW, a view of an input tensor is returned
- # - INPLACE, one or more input tensors are modified
- #
- # these descriptors are mutually exclusive and exhaustive.
- class RETURN_TYPE(Enum):
- NEW = (0,)
- VIEW = (1,)
- INPLACE = (2,)
- # TODO: when NumberType contains the sym types, can simplify this
- def number_type(x: Union[NumberType, torch.SymInt, torch.SymFloat]) -> Type:
- if isinstance(x, torch.SymInt):
- return int
- elif isinstance(x, torch.SymFloat):
- return float
- else:
- return type(x)
- # TODO: document type promotion kinds
- def elementwise_dtypes(
- *_args,
- type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND,
- ) -> Tuple[torch.dtype, torch.dtype]:
- """
- Computes the computation and result dtypes for elementwise type promotion
- on the given arguments and with the given elementwise type promotion kind.
- Note that not all inputs to an elementwise operation necessarily participate in type promotion.
- For example, the "alpha" parameter of torch.add does not participate in type promotion,
- although it may be cast to the Python type corresponding to the computation dtype that
- the type promotion algorithm determines.
- Default elementwise type promotion, which all other type promotion kinds tweak (see below),
- first decides which of four ordered types to use:
- bool -> integer -> floating point -> complex
- The selected type is the "lowest" type in the above list such that all number arguments
- have a weakly "lower" type and all tensor arguments have a weakly lower corresponding
- type for their dtype.
- Once the type is determined, the particular result dtype is found. The dtypes are
- partially ordered as follows:
- bool -> uint8, int8 -> int16 -> int32 -> int64 ->
- float16, bfloat16 -> float32 -> float64 -> complex32 -> complex64 -> complex128
- The result dtype is selected by:
- - if no tensor's dtype has the same corresponding type as the one selected,
- then the result dtype is the (default) dtype corresponding to the selected type
- (for example, 1.5 + an integer tensor has a result dtype of the default floating point dtype)
- - if the result type is complex then the dtype is:
- - the default complex dtype if there are no floating point or complex tensors
- - if there are floating point or complex tensors with one or more dimensions, then
- the complex dtype corresponding to the highest corresponding complex dtype among those tensors
- (for example, double + cfloat -> cdouble)
- - if there are only floating point or complex tensors with zero dimensions, then
- the complex dtype corresponding to the highest corresponding complex dtype among those tensors
- - if the first two cases do not apply, the result dtype is the highest dtype among
- all tensors with one or more dimensions of the output type, and if there are no such
- tensors then it's the highest dtype among all tensors with zero dimensions of the output type
- (for example, long + half -> half, even if the half tensor has zero dimensions)
- The "corresponding complex dtypes" are:
- float16 -> complex32
- bfloat16 -> complex64
- float32 -> complex64
- float64 -> complex128
- complex32 -> complex32
- complex64 -> complex64
- complex128 -> complex128
- The DEFAULT type promotion kind computes per above, and then uses the result dtype to pick a computation
- dtype by mapping low precision floating point and complex dtypes as follows:
- float16 -> float32
- bfloat16 -> float32
- complex32 -> complex64
- This is referred to as "op math", and the NO_OPMATH type promotion kind disables this mapping, making the
- computation dtype the same as the result dtype when it's selected. NO_OPMATH is appropriate for kernels
- which perform no mathematical operations on their tensors (see below for examples).
- The INT_TO_FLOAT type promotion kind maps boolean and integer result dtypes to the default floating point dtype,
- and computation dtypes to the appropriate op math dtype.
- The COMPLEX_TO_FLOAT type promotion kind maps complex result dtypes to the corresponding float dtype, following this
- mapping:
- complex32 -> float16
- complex64 -> float32
- complex128 -> float64
- Note that COMPLEX_TO_FLOAT derives the computation dtype as the DEFAULT setting does.
- The BOOL_TO_LONG type promotion kind maps boolean computation and result dtypes to long.
- The ALWAYS_BOOL type promotion kind always sets the result dtype to bool.
- Example operators for each type promotion option:
- DEFAULT : add
- NO_OPMATH : where, nextafter, cat
- INT_TO_FLOAT : sin
- COMPLEX_TO_FLOAT : abs
- BOOL_TO_LONG : pow
- ALWAYS_BOOL : eq
- """
- args = tuple(x for x in _args if x is not None)
- highest_type: type = bool
- for x in args:
- if not isinstance(x, (Number, TensorLike)):
- msg = (
- "Unexpected type {0} when computing elementwise type promotion!".format(
- str(type(x))
- )
- )
- raise ValueError(msg)
- if isinstance(x, Number):
- highest_type = get_higher_type(highest_type, number_type(x))
- else:
- # x is a TensorLike
- highest_type = get_higher_type(highest_type, dtype_to_type(x.dtype))
- result_dtype = None
- def _find_highest_dtype_filtered(
- args, filter, *, float_as_complex=False
- ) -> Optional[torch.dtype]:
- zero_dim_tensor_dtype = None
- one_plus_dim_tensor_dtype = None
- for x in args:
- if isinstance(x, TensorLike) and filter(x.dtype):
- _dtype = x.dtype
- if float_as_complex and is_float_dtype(_dtype):
- _dtype = corresponding_complex_dtype(_dtype)
- if x.ndim == 0:
- zero_dim_tensor_dtype = get_higher_dtype(
- zero_dim_tensor_dtype, _dtype
- )
- else:
- # x.ndim > 0
- one_plus_dim_tensor_dtype = get_higher_dtype(
- one_plus_dim_tensor_dtype, _dtype
- )
- # Prefers dtype of tensors with one or more dimensions
- if one_plus_dim_tensor_dtype is not None:
- return one_plus_dim_tensor_dtype
- return zero_dim_tensor_dtype
- if highest_type is float:
- result_dtype = _find_highest_dtype_filtered(args, is_float_dtype)
- result_dtype = (
- torch.get_default_dtype() if result_dtype is None else result_dtype
- )
- elif highest_type is complex:
- result_dtype = _find_highest_dtype_filtered(
- args,
- lambda x: is_float_dtype(x) or is_complex_dtype(x),
- float_as_complex=True,
- )
- if result_dtype is None:
- result_dtype = corresponding_complex_dtype(torch.get_default_dtype())
- elif highest_type is int:
- result_dtype = _find_highest_dtype_filtered(args, is_integer_dtype)
- result_dtype = torch.long if result_dtype is None else result_dtype
- else:
- # highest_type is bool
- result_dtype = torch.bool
- if type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT:
- return get_computation_dtype(result_dtype), result_dtype
- elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH:
- return result_dtype, result_dtype
- elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT:
- if is_integer_dtype(result_dtype) or is_boolean_dtype(result_dtype):
- result_dtype = torch.get_default_dtype()
- return get_computation_dtype(result_dtype), result_dtype
- elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT:
- # NOTE: computation can still occur in a complex dtype
- computation_dtype = get_computation_dtype(result_dtype)
- if is_complex_dtype(result_dtype):
- result_dtype = corresponding_real_dtype(result_dtype)
- return computation_dtype, result_dtype
- elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.BOOL_TO_LONG:
- if is_boolean_dtype(result_dtype):
- return torch.long, torch.long
- return get_computation_dtype(result_dtype), result_dtype
- elif type_promotion_kind is ELEMENTWISE_TYPE_PROMOTION_KIND.ALWAYS_BOOL:
- return get_computation_dtype(result_dtype), torch.bool
- else:
- raise ValueError(
- "Unknown type promotion kind {0}".format(str(type_promotion_kind))
- )
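- # Illustrative sketch (assuming the default dtype is torch.float32):
- # >>> t = torch.ones(2, 2, dtype=torch.int64)
- # >>> elementwise_dtypes(t, 1.5, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- # (torch.float32, torch.float32)
- # >>> h = torch.ones(2, 2, dtype=torch.float16)
- # >>> elementwise_dtypes(h, 2.0, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT)
- # (torch.float32, torch.float16)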
- def reduction_dtypes(
- arg,
- output_dtype_kind: REDUCTION_OUTPUT_TYPE_KIND,
- dtype: Optional[torch.dtype] = None,
- ) -> Tuple[torch.dtype, Optional[torch.dtype]]:
- # even though some reductions, like amin or amax, don't strictly require type promotion,
- # all the math ops (including comparisons) are still defined only for a computation type,
- # so promotion will still happen. We are doing it explicitly here
- inp_dtype = dtype if dtype is not None else arg.dtype
- computation_dtype = get_computation_dtype(inp_dtype)
- if (
- output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.SAME
- or output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
- ):
- result_dtype = dtype if dtype else arg.dtype
- if (
- output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT
- and is_complex_dtype(result_dtype)
- ):
- result_dtype = corresponding_real_dtype(result_dtype)
- elif output_dtype_kind == REDUCTION_OUTPUT_TYPE_KIND.KEEP_PROMOTED_TYPE:
- result_dtype = None
- else: # ALWAYS_BOOL
- result_dtype = torch.bool
- return computation_dtype, result_dtype
- # This function's logic is borrowed from the following functions defined in C++:
- # batched_matrix_contiguous_strides and contiguous_strides
- def make_contiguous_strides_for(
- shape: ShapeType, row_major: bool = True
- ) -> Tuple[int, ...]:
- """
- Returns the strides of a contiguous tensor if row_major=True.
- If row_major=False, it returns the strides of a contiguous batch of Fortran-contiguous matrices.
- This is often used when calling external libraries like BLAS/LAPACK/cuSolver...
- """
- # contiguous_strides from c10/util/strides.h
- validate_shape(shape)
- if not shape:
- return ()
- multiplier = 1
- strides = []
- for l in reversed(shape):
- strides.append(multiplier)
- multiplier *= sym_max(l, 1)
- result = tuple(reversed(strides))
- # batched_matrix_contiguous_strides from aten/src/ATen/native/LinearAlgebraUtils.h
- if row_major:
- return result
- else:
- if len(shape) < 2:
- return result
- return result[:-2] + (1, max(shape[-2], 1))
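- # Illustrative sketch (row-major vs. batched Fortran-order strides):
- # >>> make_contiguous_strides_for((2, 3, 4))
- # (12, 4, 1)
- # >>> make_contiguous_strides_for((2, 3, 4), row_major=False)
- # (12, 1, 3)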
- def make_channels_last_2d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
- # TODO: maybe inform the user of channels_last_3d if rank of the tensor is 5?
- check(
- len(shape) == 4,
- lambda: "Only tensors of rank 4 can use the channels_last memory format",
- )
- multiplier = 1
- strides = [0] * 4
- for idx in (1, -1, -2, 0):
- # NOTE: intentional divergence from make_contiguous_strides_for
- # This is consistent with eager
- strides[idx] = multiplier
- multiplier *= shape[idx]
- return tuple(strides)
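- # Illustrative sketch (NCHW shape in, channels-last-ordered strides out):
- # >>> make_channels_last_2d_strides_for((2, 3, 4, 5))
- # (60, 1, 15, 3)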
- def make_channels_last_3d_strides_for(shape: ShapeType) -> Tuple[int, ...]:
- check(
- len(shape) == 5,
- lambda: "Only tensors of rank 5 can use the channels_last_3d memory format",
- )
- multiplier = 1
- strides = [0] * 5
- for idx in (1, -1, -2, -3, 0):
- # NOTE: intentional divergence from make_contiguous_strides_for
- # This is consistent with eager
- strides[idx] = multiplier
- multiplier *= shape[idx]
- return tuple(strides)
- def make_channels_last_strides_for(shape: ShapeType) -> Tuple[int, ...]:
- ndim = len(shape) if isinstance(shape, Sequence) else 1
- if ndim == 4:
- return make_channels_last_2d_strides_for(shape)
- elif ndim == 5:
- return make_channels_last_3d_strides_for(shape)
- else:
- raise RuntimeError(
- f"no channels last format strides exist in {ndim} dimensions"
- )
- def compute_reduction_output_shape(
- shape: ShapeType, dimensions: Sequence
- ) -> Tuple[int, ...]:
- for idx in dimensions:
- validate_idx(len(shape), idx)
- new_shape = []
- for idx in range(len(shape)):
- if idx in dimensions:
- continue
- new_shape.append(shape[idx])
- return tuple(new_shape)
- def validate_no_repeating_dims(dims: Sequence):
- if len(dims) != len(set(dims)):
- raise RuntimeError("duplicate value in the list of dims")
- def reduction_dims(shape: ShapeType, dims: Optional[Sequence]) -> Tuple[int, ...]:
- if dims is None:
- return tuple(range(len(shape)))
- dims = tuple(canonicalize_dim(len(shape), idx) for idx in dims)
- validate_no_repeating_dims(dims)
- return dims
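- # Illustrative sketch: None means reduce over all dims; negative dims are wrapped:
- # >>> reduction_dims((2, 3, 4), None)
- # (0, 1, 2)
- # >>> reduction_dims((2, 3, 4), (-1, 0))
- # (2, 0)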
- def set_correction(
- unbiased: Optional[bool] = None,
- correction: Optional[int] = None,
- ):
- if correction is not None and unbiased is not None:
- raise RuntimeError("cannot specify both correction and unbiased arguments")
- elif correction is None and unbiased is None:
- correction = 1
- elif correction is None and unbiased is not None:
- correction = 0 if unbiased is False else 1
- # NB: we don't actually support symint here, but it's harmless to accept
- if not isinstance(correction, IntLike):
- raise ValueError("correction argument should be integer")
- if correction < 0:
- raise ValueError("correction argument should be non-negative")
- return correction
- def compute_required_storage_length(
- shape: ShapeType, strides: StrideType, storage_offset: int
- ) -> int:
- """Computes the minimum storage size to hold the given tensor geometry.
- Example
- =======
- This is the size of a newly allocated tensor's storage, in units of elements
- >>> t = torch.empty((10, 20))
- >>> compute_required_storage_length(t.shape, t.stride(), t.storage_offset())
- 200
- >>> # xdoctest: +SKIP(failing)
- >>> t2 = torch.empty_strided((1, 2, 3), (5, 7, 11))
- >>> size = compute_required_storage_length(t2.shape, t2.stride(), t2.storage_offset())
- >>> size == t2.storage().size()
- True
- A valid tensor may have a larger storage size, but never smaller
- >>> slice = torch.empty(100)[20:40]
- >>> slice.storage().size()
- 100
- >>> compute_required_storage_length(slice.shape, slice.stride(), slice.storage_offset())
- 40
- """
- # Short-circuits if the shape has no elements
- if reduce(operator.mul, shape, 1) == 0:
- return 0
- max_offset = sum((x - 1) * y for x, y in zip(shape, strides))
- # +1 to account for the first element, from which offsets are taken
- return 1 + storage_offset + max_offset
- def check_in_bounds_for_storage(
- a: torch.TypedStorage, shape: ShapeType, strides: StrideType, storage_offset: int
- ):
- """
- Determines if the given shape, strides, and offset are valid for the given storage.
- """
- required_length = compute_required_storage_length(shape, strides, storage_offset)
- if a.size() < required_length:
- msg = (
- "Can't view a storage of size {0} with an offset of {1}, shape of {2}, and strides of {3}, "
- "which requires a storage of size {4}".format(
- a.size(), storage_offset, str(shape), str(strides), required_length
- )
- )
- raise ValueError(msg)
- def check(
- b: bool, s: Callable[[], str], exc_type: Type[Exception] = RuntimeError
- ) -> None:
- """
- Helper function for raising an error_type (default: RuntimeError) if a boolean condition fails.
- The error message is a callable producing a string (to avoid wasting time
- on string formatting in the non-error case, and also to make it easier for
- torchdynamo to trace).
- """
- if not b:
- raise exc_type(s())
- # This combines is_channels_last_strides_2d and is_channels_last_strides_3d in
- # c10/core/MemoryFormat.h into one function
- def are_strides_like_channels_last(
- shape: Sequence[int], strides: Sequence[int]
- ) -> bool:
- ndim = len(shape)
- if ndim == 4:
- # Check for channels_last_2d
- dim_order = [1, 3, 2, 0]
- elif ndim == 5:
- # Check for channels_last_3d
- dim_order = [1, 4, 3, 2, 0]
- else:
- return False
- if strides[1] == 0:
- return False
- min = 0
- for d in dim_order:
- if shape[d] == 0:
- return False
- if strides[d] < min:
- return False
- if d == 0 and min == strides[1]:
- return False
- min = strides[d]
- if strides[d] > 1:
- min *= shape[d]
- return True
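- # Illustrative sketch for an NCHW shape:
- # >>> are_strides_like_channels_last((2, 3, 4, 5), (60, 1, 15, 3))
- # True
- # >>> are_strides_like_channels_last((2, 3, 4, 5), (60, 20, 5, 1))
- # False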
- def suggest_memory_format(x: TensorLikeType) -> torch.memory_format:
- if x.layout != torch.strided:
- return torch.contiguous_format
- if are_strides_like_channels_last(x.shape, x.stride()):
- return torch.channels_last if x.ndim == 4 else torch.channels_last_3d
- return torch.contiguous_format
- def prod(xs: Sequence[NumberType]) -> NumberType:
- """Product of elements in input sequence. Returns 1 for empty sequence"""
- return reduce(operator.mul, xs, 1)
- def is_expandable_to(shape: ShapeType, desired: ShapeType) -> bool:
- """Checks if a shape can be expanded to another shape.
- This is equivalent to checking if the two shapes are broadcastable.
- """
- # This is a Python implementation of
- # aten/src/ATen/ExpandUtils.h:is_expandable_to
- if len(shape) > len(desired):
- return False
- for i in range(len(shape)):
- if shape[-i - 1] != desired[-i - 1] and shape[-i - 1] != 1:
- return False
- return True
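- # Illustrative sketch of broadcast-compatibility:
- # >>> is_expandable_to((3, 1), (2, 3, 4))
- # True
- # >>> is_expandable_to((2, 3), (3, 3))
- # False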
- def mask_tensor(mask: TensorLikeType, t: TensorLikeType):
- """
- Similar to torch.where(mask, t, 0) but if t is boolean,
- result is also boolean and not promoted to int.
- """
- # torch.where(mask, t, False) is equivalent
- # but feels hacky and might break in the future
- if t.dtype is torch.bool:
- return mask.logical_and(t)
- else:
- return torch.where(mask, t, 0)
- def get_aten_op(fn: Callable, name: str):
- """
- Given a reference function (via its __module__) and its name, returns
- (our best guess of) the ATen name of the associated operation.
- Note: In ATen, the __name__ of a function within a module often
- starts with the module name. E.g. linalg_eigh, or special_zeta
- """
- module = fn.__module__
- prefix = "torch._refs"
- assert module.startswith(prefix)
- module = module[len(prefix):]
- # We want to go from .special / .nn.functional
- # to special and special_ / nn_functional_
- if module:
- module = module[1:]
- module = module.replace(".", "_")
- module = module + "_"
- return getattr(torch._ops.ops.aten, f"{module}{name}")
- def dtype_or_default(dtype: Optional[torch.dtype]) -> torch.dtype:
- return dtype if dtype is not None else torch.get_default_dtype()
- def device_or_default(device: Optional[torch.device]) -> torch.device:
- return device if device is not None else torch.device("cpu")
- def layout_or_default(layout: Optional[torch.layout]) -> torch.layout:
- return layout if layout is not None else torch.strided
- def clone_preserve_strides(x):
- needed_size = compute_required_storage_length(
- x.size(), x.stride(), x.storage_offset()
- )
- # Our eager implementations for *_scatter ops are all primitives w.r.t autograd,
- # so these as_strided() calls are not seen by autograd.
- # We need to mimic this behavior in our ref/prim implementations.
- # TODO: a better way to handle this would be with a new op, "_unsafe_as_strided"
- # We should revisit this when we add a compositional as_strided op,
- # and also as part of https://github.com/pytorch/pytorch/issues/90507
- try:
- old = torch._C._dispatch_tls_is_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView)
- torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, True)
- buffer = torch.as_strided(x, (needed_size,), (1,), 0).clone()
- return torch.as_strided(buffer, x.size(), x.stride(), x.storage_offset())
- finally:
- torch._C._dispatch_tls_set_dispatch_key_excluded(torch._C.DispatchKey.ADInplaceOrView, old)