- import abc
- import cmath
- import collections.abc
- import contextlib
- import warnings
- from typing import (
- Any,
- Callable,
- Collection,
- Dict,
- List,
- NoReturn,
- Optional,
- Sequence,
- Tuple,
- Type,
- Union,
- )
- import torch
- try:
- import numpy as np
- NUMPY_AVAILABLE = True
- except ModuleNotFoundError:
- NUMPY_AVAILABLE = False
- class ErrorMeta(Exception):
- """Internal testing exception that makes that carries error metadata."""
- def __init__(
- self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
- ) -> None:
- super().__init__(
- "If you are a user and see this message during normal operation "
- "please file an issue at https://github.com/pytorch/pytorch/issues. "
- "If you are a developer and working on the comparison functions, please `raise ErrorMeta().to_error()` "
- "for user facing errors."
- )
- self.type = type
- self.msg = msg
- self.id = id
- def to_error(
- self, msg: Optional[Union[str, Callable[[str], str]]] = None
- ) -> Exception:
- if not isinstance(msg, str):
- generated_msg = self.msg
- if self.id:
- generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
- msg = msg(generated_msg) if callable(msg) else generated_msg
- return self.type(msg)
- # Some analysis of tolerance by logging tests from test_torch.py can be found in
- # https://github.com/pytorch/pytorch/pull/32538.
- # {dtype: (rtol, atol)}
- _DTYPE_PRECISIONS = {
- torch.float16: (0.001, 1e-5),
- torch.bfloat16: (0.016, 1e-5),
- torch.float32: (1.3e-6, 1e-5),
- torch.float64: (1e-7, 1e-7),
- torch.complex32: (0.001, 1e-5),
- torch.complex64: (1.3e-6, 1e-5),
- torch.complex128: (1e-7, 1e-7),
- }
- # The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
- # their dequantized and floating point representation. For more details, see `TensorLikePair._compare_quantized_values`.
- _DTYPE_PRECISIONS.update(
- {
- dtype: _DTYPE_PRECISIONS[torch.float32]
- for dtype in (
- torch.quint8,
- torch.quint2x4,
- torch.quint4x2,
- torch.qint8,
- torch.qint32,
- )
- }
- )
- def default_tolerances(
- *inputs: Union[torch.Tensor, torch.dtype],
- dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None,
- ) -> Tuple[float, float]:
- """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
- See :func:`assert_close` for a table of the default tolerance for each dtype.
- Returns:
- (Tuple[float, float]): Loosest tolerances of all input dtypes.
- """
- dtypes = []
- for input in inputs:
- if isinstance(input, torch.Tensor):
- dtypes.append(input.dtype)
- elif isinstance(input, torch.dtype):
- dtypes.append(input)
- else:
- raise TypeError(
- f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead."
- )
- dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS
- rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes])
- return max(rtols), max(atols)
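- # Illustrative sketch (not part of the original comments): mixing dtypes picks the loosest
- # tolerance of each kind, e.g.
- # >>> default_tolerances(torch.float16, torch.float64)
- # (0.001, 1e-05)
- # since the float16 rtol and atol are both larger than the float64 values.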
- def get_tolerances(
- *inputs: Union[torch.Tensor, torch.dtype],
- rtol: Optional[float],
- atol: Optional[float],
- id: Tuple[Any, ...] = (),
- ) -> Tuple[float, float]:
- """Gets absolute and relative to be used for numeric comparisons.
- If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of
- :func:`default_tolerances` is used.
- Raises:
- ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified.
- Returns:
- (Tuple[float, float]): Valid absolute and relative tolerances.
- """
- if (rtol is None) ^ (atol is None):
- # We require both tolerances to be either omitted or specified, because specifying only one might lead to surprising
- # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
- raise ErrorMeta(
- ValueError,
- f"Both 'rtol' and 'atol' must be either specified or omitted, "
- f"but got no {'rtol' if rtol is None else 'atol'}.",
- id=id,
- )
- elif rtol is not None and atol is not None:
- return rtol, atol
- else:
- return default_tolerances(*inputs)
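- # Usage sketch (illustrative, derived from the logic above): both tolerances given is a
- # passthrough, both omitted falls back to the dtype defaults, and specifying only one
- # raises an ErrorMeta wrapping a ValueError.
- # >>> get_tolerances(torch.float32, rtol=0.0, atol=0.0)
- # (0.0, 0.0)
- # >>> get_tolerances(torch.float32, rtol=None, atol=None)
- # (1.3e-06, 1e-05)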
- def _make_mismatch_msg(
- *,
- default_identifier: str,
- identifier: Optional[Union[str, Callable[[str], str]]] = None,
- extra: Optional[str] = None,
- abs_diff: float,
- abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
- atol: float,
- rel_diff: float,
- rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
- rtol: float,
- ) -> str:
- """Makes a mismatch error message for numeric values.
- Args:
- default_identifier (str): Default description of the compared values, e.g. "Tensor-likes".
- identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides
- ``default_identifier``. Can be passed as callable in which case it will be called with
- ``default_identifier`` to create the description at runtime.
- extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
- abs_diff (float): Absolute difference.
- abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference.
- atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are
- ``> 0``.
- rel_diff (float): Relative difference.
- rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference.
- rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are
- ``> 0``.
- """
- equality = rtol == 0 and atol == 0
- def make_diff_msg(
- *,
- type: str,
- diff: float,
- idx: Optional[Union[int, Tuple[int, ...]]],
- tol: float,
- ) -> str:
- if idx is None:
- msg = f"{type.title()} difference: {diff}"
- else:
- msg = f"Greatest {type} difference: {diff} at index {idx}"
- if not equality:
- msg += f" (up to {tol} allowed)"
- return msg + "\n"
- if identifier is None:
- identifier = default_identifier
- elif callable(identifier):
- identifier = identifier(default_identifier)
- msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n"
- if extra:
- msg += f"{extra.strip()}\n"
- msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
- msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)
- return msg.strip()
- def make_scalar_mismatch_msg(
- actual: Union[int, float, complex],
- expected: Union[int, float, complex],
- *,
- rtol: float,
- atol: float,
- identifier: Optional[Union[str, Callable[[str], str]]] = None,
- ) -> str:
- """Makes a mismatch error message for scalars.
- Args:
- actual (Union[int, float, complex]): Actual scalar.
- expected (Union[int, float, complex]): Expected scalar.
- rtol (float): Relative tolerance.
- atol (float): Absolute tolerance.
- identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed
- as callable in which case it will be called with the default value to create the description at runtime.
- Defaults to "Scalars".
- """
- abs_diff = abs(actual - expected)
- rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
- return _make_mismatch_msg(
- default_identifier="Scalars",
- identifier=identifier,
- abs_diff=abs_diff,
- atol=atol,
- rel_diff=rel_diff,
- rtol=rtol,
- )
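- # Illustrative output of the builder above for an equality check (rtol == atol == 0):
- # >>> print(make_scalar_mismatch_msg(1.0, 2.0, rtol=0.0, atol=0.0))
- # Scalars are not equal!
- #
- # Absolute difference: 1.0
- # Relative difference: 0.5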
- def make_tensor_mismatch_msg(
- actual: torch.Tensor,
- expected: torch.Tensor,
- mismatches: torch.Tensor,
- *,
- rtol: float,
- atol: float,
- identifier: Optional[Union[str, Callable[[str], str]]] = None,
- ):
- """Makes a mismatch error message for tensors.
- Args:
- actual (torch.Tensor): Actual tensor.
- expected (torch.Tensor): Expected tensor.
- mismatches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
- location of mismatches.
- rtol (float): Relative tolerance.
- atol (float): Absolute tolerance.
- identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed
- as callable in which case it will be called with the default value to create the description at runtime.
- Defaults to "Tensor-likes".
- """
- def unravel_flat_index(flat_index: int) -> Tuple[int, ...]:
- if not mismatches.shape:
- return ()
- inverse_index = []
- for size in mismatches.shape[::-1]:
- div, mod = divmod(flat_index, size)
- flat_index = div
- inverse_index.append(mod)
- return tuple(inverse_index[::-1])
- number_of_elements = mismatches.numel()
- total_mismatches = torch.sum(mismatches).item()
- extra = (
- f"Mismatched elements: {total_mismatches} / {number_of_elements} "
- f"({total_mismatches / number_of_elements:.1%})"
- )
- a_flat = actual.flatten()
- b_flat = expected.flatten()
- matches_flat = ~mismatches.flatten()
- abs_diff = torch.abs(a_flat - b_flat)
- # Ensure that only mismatches are used for the max_abs_diff computation
- abs_diff[matches_flat] = 0
- max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)
- rel_diff = abs_diff / torch.abs(b_flat)
- # Ensure that only mismatches are used for the max_rel_diff computation
- rel_diff[matches_flat] = 0
- max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
- return _make_mismatch_msg(
- default_identifier="Tensor-likes",
- identifier=identifier,
- extra=extra,
- abs_diff=max_abs_diff.item(),
- abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)),
- atol=atol,
- rel_diff=max_rel_diff.item(),
- rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)),
- rtol=rtol,
- )
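- # Worked example of the flat-index unraveling used above: for a mask of shape (2, 3),
- # flat index 4 maps to the multi-dimensional index (1, 1), since divmod(4, 3) == (1, 1)
- # for the innermost dimension and divmod(1, 2) == (0, 1) for the outer one.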
- class UnsupportedInputs(Exception): # noqa: B903
- """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs."""
- class Pair(abc.ABC):
- """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`.
- Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison.
- Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the
- super class. Raising an :class:`UnsupportedInputs` during construction indicates that the pair is not able to
- handle the inputs and the next pair type will be tried.
- All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can
- be used to automatically handle overwriting the message with a user supplied one and id handling.
- """
- def __init__(
- self,
- actual: Any,
- expected: Any,
- *,
- id: Tuple[Any, ...] = (),
- **unknown_parameters: Any,
- ) -> None:
- self.actual = actual
- self.expected = expected
- self.id = id
- self._unknown_parameters = unknown_parameters
- @staticmethod
- def _inputs_not_supported() -> NoReturn:
- raise UnsupportedInputs()
- @staticmethod
- def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):
- """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise."""
- if not all(isinstance(input, cls) for input in inputs):
- Pair._inputs_not_supported()
- def _fail(
- self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
- ) -> NoReturn:
- """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id.
- .. warning::
- If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id``
- explicitly.
- """
- raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id)
- @abc.abstractmethod
- def compare(self) -> None:
- """Compares the inputs and raises an :class`ErrorMeta` in case they mismatch."""
- def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]:
- """Returns extra information that will be included in the representation.
- Should be overwritten by all subclasses that use additional options. The representation of the object will only
- be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of
- key-value-pairs or attribute names.
- """
- return []
- def __repr__(self) -> str:
- head = f"{type(self).__name__}("
- tail = ")"
- body = [
- f" {name}={value!s},"
- for name, value in [
- ("id", self.id),
- ("actual", self.actual),
- ("expected", self.expected),
- *[
- (extra, getattr(self, extra)) if isinstance(extra, str) else extra
- for extra in self.extra_repr()
- ],
- ]
- ]
- return "\n".join((head, *body, *tail))
- class ObjectPair(Pair):
- """Pair for any type of inputs that will be compared with the `==` operator.
- .. note::
- Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs
- couldn't handle the inputs.
- """
- def compare(self) -> None:
- try:
- equal = self.actual == self.expected
- except Exception as error:
- # We are not using `self._fail` here since we need the exception chaining
- raise ErrorMeta(
- ValueError,
- f"{self.actual} == {self.expected} failed with:\n{error}.",
- id=self.id,
- ) from error
- if not equal:
- self._fail(AssertionError, f"{self.actual} != {self.expected}")
- class NonePair(Pair):
- """Pair for ``None`` inputs."""
- def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
- if not (actual is None or expected is None):
- self._inputs_not_supported()
- super().__init__(actual, expected, **other_parameters)
- def compare(self) -> None:
- if not (self.actual is None and self.expected is None):
- self._fail(
- AssertionError, f"None mismatch: {self.actual} is not {self.expected}"
- )
- class BooleanPair(Pair):
- """Pair for :class:`bool` inputs.
- .. note::
- If ``numpy`` is available, also handles :class:`numpy.bool_` inputs.
- """
- def __init__(
- self,
- actual: Any,
- expected: Any,
- *,
- id: Tuple[Any, ...],
- **other_parameters: Any,
- ) -> None:
- actual, expected = self._process_inputs(actual, expected, id=id)
- super().__init__(actual, expected, **other_parameters)
- @property
- def _supported_types(self) -> Tuple[Type, ...]:
- cls: List[Type] = [bool]
- if NUMPY_AVAILABLE:
- cls.append(np.bool_)
- return tuple(cls)
- def _process_inputs(
- self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
- ) -> Tuple[bool, bool]:
- self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
- actual, expected = [
- self._to_bool(bool_like, id=id) for bool_like in (actual, expected)
- ]
- return actual, expected
- def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool:
- if isinstance(bool_like, bool):
- return bool_like
- elif NUMPY_AVAILABLE and isinstance(bool_like, np.bool_):
- return bool_like.item()
- else:
- raise ErrorMeta(
- TypeError, f"Unknown boolean type {type(bool_like)}.", id=id
- )
- def compare(self) -> None:
- if self.actual is not self.expected:
- self._fail(
- AssertionError,
- f"Booleans mismatch: {self.actual} is not {self.expected}",
- )
- class NumberPair(Pair):
- """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs.
- .. note::
- If ``numpy`` is available, also handles :class:`numpy.number` inputs.
- Kwargs:
- rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
- values based on the type are selected with the below table.
- atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
- values based on the type are selected with the below table.
- equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
- check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``.
- The following table displays correspondence between Python number type and the ``torch.dtype``'s. See
- :func:`assert_close` for the corresponding tolerances.
- +------------------+-------------------------------+
- | ``type`` | corresponding ``torch.dtype`` |
- +==================+===============================+
- | :class:`int` | :attr:`~torch.int64` |
- +------------------+-------------------------------+
- | :class:`float` | :attr:`~torch.float64` |
- +------------------+-------------------------------+
- | :class:`complex` | :attr:`~torch.complex128` |
- +------------------+-------------------------------+
- """
- _TYPE_TO_DTYPE = {
- int: torch.int64,
- float: torch.float64,
- complex: torch.complex128,
- }
- _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys())
- def __init__(
- self,
- actual: Any,
- expected: Any,
- *,
- id: Tuple[Any, ...] = (),
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
- equal_nan: bool = False,
- check_dtype: bool = False,
- **other_parameters: Any,
- ) -> None:
- actual, expected = self._process_inputs(actual, expected, id=id)
- super().__init__(actual, expected, id=id, **other_parameters)
- self.rtol, self.atol = get_tolerances(
- *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)],
- rtol=rtol,
- atol=atol,
- id=id,
- )
- self.equal_nan = equal_nan
- self.check_dtype = check_dtype
- @property
- def _supported_types(self) -> Tuple[Type, ...]:
- cls = list(self._NUMBER_TYPES)
- if NUMPY_AVAILABLE:
- cls.append(np.number)
- return tuple(cls)
- def _process_inputs(
- self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
- ) -> Tuple[Union[int, float, complex], Union[int, float, complex]]:
- self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
- actual, expected = [
- self._to_number(number_like, id=id) for number_like in (actual, expected)
- ]
- return actual, expected
- def _to_number(
- self, number_like: Any, *, id: Tuple[Any, ...]
- ) -> Union[int, float, complex]:
- if NUMPY_AVAILABLE and isinstance(number_like, np.number):
- return number_like.item()
- elif isinstance(number_like, self._NUMBER_TYPES):
- return number_like
- else:
- raise ErrorMeta(
- TypeError, f"Unknown number type {type(number_like)}.", id=id
- )
- def compare(self) -> None:
- if self.check_dtype and type(self.actual) is not type(self.expected):
- self._fail(
- AssertionError,
- f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.",
- )
- if self.actual == self.expected:
- return
- if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected):
- return
- abs_diff = abs(self.actual - self.expected)
- tolerance = self.atol + self.rtol * abs(self.expected)
- if cmath.isfinite(abs_diff) and abs_diff <= tolerance:
- return
- self._fail(
- AssertionError,
- make_scalar_mismatch_msg(
- self.actual, self.expected, rtol=self.rtol, atol=self.atol
- ),
- )
- def extra_repr(self) -> Sequence[str]:
- return (
- "rtol",
- "atol",
- "equal_nan",
- "check_dtype",
- )
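- # Worked example of the closeness check above (illustrative): comparing the Python floats
- # 1.0 and 1.0 + 1e-9 uses the torch.float64 defaults (rtol=1e-7, atol=1e-7), so the absolute
- # difference of 1e-9 is well below atol + rtol * abs(expected) ~= 2e-7 and the pair compares
- # as close.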
- class TensorLikePair(Pair):
- """Pair for :class:`torch.Tensor`-like inputs.
- Kwargs:
- allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types
- are allowed. Otherwise type equality is required.
- rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
- values based on the type are selected. See :func:`assert_close` for details.
- atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
- values based on the type are selected. See :func:`assert_close` for details.
- equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
- check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
- :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
- :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
- check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
- check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
- :func:`torch.promote_types`) before being compared.
- check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
- check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
- compared.
- check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
- """
- def __init__(
- self,
- actual: Any,
- expected: Any,
- *,
- id: Tuple[Any, ...] = (),
- allow_subclasses: bool = True,
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
- equal_nan: bool = False,
- check_device: bool = True,
- check_dtype: bool = True,
- check_layout: bool = True,
- check_stride: bool = False,
- **other_parameters: Any,
- ):
- actual, expected = self._process_inputs(
- actual, expected, id=id, allow_subclasses=allow_subclasses
- )
- super().__init__(actual, expected, id=id, **other_parameters)
- self.rtol, self.atol = get_tolerances(
- actual, expected, rtol=rtol, atol=atol, id=self.id
- )
- self.equal_nan = equal_nan
- self.check_device = check_device
- self.check_dtype = check_dtype
- self.check_layout = check_layout
- self.check_stride = check_stride
- def _process_inputs(
- self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- directly_related = isinstance(actual, type(expected)) or isinstance(
- expected, type(actual)
- )
- if not directly_related:
- self._inputs_not_supported()
- if not allow_subclasses and type(actual) is not type(expected):
- self._inputs_not_supported()
- actual, expected = [self._to_tensor(input) for input in (actual, expected)]
- for tensor in (actual, expected):
- self._check_supported(tensor, id=id)
- return actual, expected
- def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
- if isinstance(tensor_like, torch.Tensor):
- return tensor_like
- try:
- return torch.as_tensor(tensor_like)
- except Exception:
- self._inputs_not_supported()
- def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
- if tensor.layout not in {
- torch.strided,
- torch.sparse_coo,
- torch.sparse_csr,
- torch.sparse_csc,
- torch.sparse_bsr,
- torch.sparse_bsc,
- }:
- raise ErrorMeta(
- ValueError, f"Unsupported tensor layout {tensor.layout}", id=id
- )
- def compare(self) -> None:
- actual, expected = self.actual, self.expected
- self._compare_attributes(actual, expected)
- if any(input.device.type == "meta" for input in (actual, expected)):
- return
- actual, expected = self._equalize_attributes(actual, expected)
- self._compare_values(actual, expected)
- def _compare_attributes(
- self,
- actual: torch.Tensor,
- expected: torch.Tensor,
- ) -> None:
- """Checks if the attributes of two tensors match.
- Always checks
- - the :attr:`~torch.Tensor.shape`,
- - whether both inputs are quantized or not,
- - and if they use the same quantization scheme.
- Checks for
- - :attr:`~torch.Tensor.layout`,
- - :meth:`~torch.Tensor.stride`,
- - :attr:`~torch.Tensor.device`, and
- - :attr:`~torch.Tensor.dtype`
- are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair.
- """
- def raise_mismatch_error(
- attribute_name: str, actual_value: Any, expected_value: Any
- ) -> NoReturn:
- self._fail(
- AssertionError,
- f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.",
- )
- if actual.shape != expected.shape:
- raise_mismatch_error("shape", actual.shape, expected.shape)
- if actual.is_quantized != expected.is_quantized:
- raise_mismatch_error(
- "is_quantized", actual.is_quantized, expected.is_quantized
- )
- elif actual.is_quantized and actual.qscheme() != expected.qscheme():
- raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme())
- if actual.layout != expected.layout:
- if self.check_layout:
- raise_mismatch_error("layout", actual.layout, expected.layout)
- elif (
- actual.layout == torch.strided
- and self.check_stride
- and actual.stride() != expected.stride()
- ):
- raise_mismatch_error("stride()", actual.stride(), expected.stride())
- if self.check_device and actual.device != expected.device:
- raise_mismatch_error("device", actual.device, expected.device)
- if self.check_dtype and actual.dtype != expected.dtype:
- raise_mismatch_error("dtype", actual.dtype, expected.dtype)
- def _equalize_attributes(
- self, actual: torch.Tensor, expected: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Equalizes some attributes of two tensors for value comparison.
- If ``actual`` and ``expected`` are ...
- - ... not on the same :attr:`~torch.Tensor.device`, they are moved to CPU memory.
- - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
- :func:`torch.promote_types`).
- - ... not of the same ``layout``, they are converted to strided tensors.
- Args:
- actual (Tensor): Actual tensor.
- expected (Tensor): Expected tensor.
- Returns:
- (Tuple[Tensor, Tensor]): Equalized tensors.
- """
- # The comparison logic uses operators currently not supported by the MPS backends.
- # See https://github.com/pytorch/pytorch/issues/77144 for details.
- # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend
- if actual.is_mps or expected.is_mps: # type: ignore[attr-defined]
- actual = actual.cpu()
- expected = expected.cpu()
- if actual.device != expected.device:
- actual = actual.cpu()
- expected = expected.cpu()
- if actual.dtype != expected.dtype:
- dtype = torch.promote_types(actual.dtype, expected.dtype)
- actual = actual.to(dtype)
- expected = expected.to(dtype)
- if actual.layout != expected.layout:
- # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
- actual = actual.to_dense() if actual.layout != torch.strided else actual
- expected = (
- expected.to_dense() if expected.layout != torch.strided else expected
- )
- return actual, expected
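- # Illustrative example of the dtype equalization above: comparing a torch.float16 tensor
- # against a torch.float32 tensor promotes both operands to torch.promote_types(torch.float16,
- # torch.float32) == torch.float32 before the values are compared.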
- def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None:
- if actual.is_quantized:
- compare_fn = self._compare_quantized_values
- elif actual.is_sparse:
- compare_fn = self._compare_sparse_coo_values
- elif actual.layout in {
- torch.sparse_csr,
- torch.sparse_csc,
- torch.sparse_bsr,
- torch.sparse_bsc,
- }:
- compare_fn = self._compare_sparse_compressed_values
- else:
- compare_fn = self._compare_regular_values_close
- compare_fn(
- actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan
- )
- def _compare_quantized_values(
- self,
- actual: torch.Tensor,
- expected: torch.Tensor,
- *,
- rtol: float,
- atol: float,
- equal_nan: bool,
- ) -> None:
- """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness.
- .. note::
- A detailed discussion about why only the dequantized variant is checked for closeness rather than checking
- the individual quantization parameters for closeness and the integer representation for equality can be
- found in https://github.com/pytorch/pytorch/issues/68548.
- """
- return self._compare_regular_values_close(
- actual.dequantize(),
- expected.dequantize(),
- rtol=rtol,
- atol=atol,
- equal_nan=equal_nan,
- identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}",
- )
- def _compare_sparse_coo_values(
- self,
- actual: torch.Tensor,
- expected: torch.Tensor,
- *,
- rtol: float,
- atol: float,
- equal_nan: bool,
- ) -> None:
- """Compares sparse COO tensors by comparing
- - the number of sparse dimensions,
- - the number of non-zero elements (nnz) for equality,
- - the indices for equality, and
- - the values for closeness.
- """
- if actual.sparse_dim() != expected.sparse_dim():
- self._fail(
- AssertionError,
- (
- f"The number of sparse dimensions in sparse COO tensors does not match: "
- f"{actual.sparse_dim()} != {expected.sparse_dim()}"
- ),
- )
- if actual._nnz() != expected._nnz():
- self._fail(
- AssertionError,
- (
- f"The number of specified values in sparse COO tensors does not match: "
- f"{actual._nnz()} != {expected._nnz()}"
- ),
- )
- self._compare_regular_values_equal(
- actual._indices(),
- expected._indices(),
- identifier="Sparse COO indices",
- )
- self._compare_regular_values_close(
- actual._values(),
- expected._values(),
- rtol=rtol,
- atol=atol,
- equal_nan=equal_nan,
- identifier="Sparse COO values",
- )
- def _compare_sparse_compressed_values(
- self,
- actual: torch.Tensor,
- expected: torch.Tensor,
- *,
- rtol: float,
- atol: float,
- equal_nan: bool,
- ) -> None:
- """Compares sparse compressed tensors by comparing
- - the number of non-zero elements (nnz) for equality,
- - the plain indices for equality,
- - the compressed indices for equality, and
- - the values for closeness.
- """
- format_name, compressed_indices_method, plain_indices_method = {
- torch.sparse_csr: (
- "CSR",
- torch.Tensor.crow_indices,
- torch.Tensor.col_indices,
- ),
- torch.sparse_csc: (
- "CSC",
- torch.Tensor.ccol_indices,
- torch.Tensor.row_indices,
- ),
- torch.sparse_bsr: (
- "BSR",
- torch.Tensor.crow_indices,
- torch.Tensor.col_indices,
- ),
- torch.sparse_bsc: (
- "BSC",
- torch.Tensor.ccol_indices,
- torch.Tensor.row_indices,
- ),
- }[actual.layout]
- if actual._nnz() != expected._nnz():
- self._fail(
- AssertionError,
- (
- f"The number of specified values in sparse {format_name} tensors does not match: "
- f"{actual._nnz()} != {expected._nnz()}"
- ),
- )
- self._compare_regular_values_equal(
- compressed_indices_method(actual),
- compressed_indices_method(expected),
- identifier=f"Sparse {format_name} {compressed_indices_method.__name__}",
- )
- self._compare_regular_values_equal(
- plain_indices_method(actual),
- plain_indices_method(expected),
- identifier=f"Sparse {format_name} {plain_indices_method.__name__}",
- )
- self._compare_regular_values_close(
- actual.values(),
- expected.values(),
- rtol=rtol,
- atol=atol,
- equal_nan=equal_nan,
- identifier=f"Sparse {format_name} values",
- )
- def _compare_regular_values_equal(
- self,
- actual: torch.Tensor,
- expected: torch.Tensor,
- *,
- equal_nan: bool = False,
- identifier: Optional[Union[str, Callable[[str], str]]] = None,
- ) -> None:
- """Checks if the values of two tensors are equal."""
- self._compare_regular_values_close(
- actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier
- )
- def _compare_regular_values_close(
- self,
- actual: torch.Tensor,
- expected: torch.Tensor,
- *,
- rtol: float,
- atol: float,
- equal_nan: bool,
- identifier: Optional[Union[str, Callable[[str], str]]] = None,
- ) -> None:
- """Checks if the values of two tensors are close up to a desired tolerance."""
- actual, expected = self._promote_for_comparison(actual, expected)
- matches = torch.isclose(
- actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan
- )
- if torch.all(matches):
- return
- if actual.shape == torch.Size([]):
- msg = make_scalar_mismatch_msg(
- actual.item(),
- expected.item(),
- rtol=rtol,
- atol=atol,
- identifier=identifier,
- )
- else:
- msg = make_tensor_mismatch_msg(
- actual, expected, ~matches, rtol=rtol, atol=atol, identifier=identifier
- )
- self._fail(AssertionError, msg)
- def _promote_for_comparison(
- self, actual: torch.Tensor, expected: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Promotes the inputs to the comparison dtype based on the input dtype.
- Returns:
- Inputs promoted to the highest precision dtype of the same dtype category. :class:`torch.bool` is treated
- as integral dtype.
- """
- # This is called after self._equalize_attributes() and thus `actual` and `expected` already have the same dtype.
- if actual.dtype.is_complex:
- dtype = torch.complex128
- elif actual.dtype.is_floating_point:
- dtype = torch.float64
- else:
- dtype = torch.int64
- return actual.to(dtype), expected.to(dtype)
- def extra_repr(self) -> Sequence[str]:
- return (
- "rtol",
- "atol",
- "equal_nan",
- "check_device",
- "check_dtype",
- "check_layout",
- "check_stride",
- )
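- # Illustrative round trip through the pair above: constructing it from two equal strided
- # tensors and calling compare() returns without raising.
- # >>> pair = TensorLikePair(torch.zeros(3), torch.zeros(3))
- # >>> pair.compare()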
- def originate_pairs(
- actual: Any,
- expected: Any,
- *,
- pair_types: Sequence[Type[Pair]],
- sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
- mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
- id: Tuple[Any, ...] = (),
- **options: Any,
- ) -> List[Pair]:
- """Originates pairs from the individual inputs.
- ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
- :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them.
- Args:
- actual (Any): Actual input.
- expected (Any): Expected input.
- pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried, in order, to construct a pair
- from the inputs. The first pair type that can handle the inputs will be used.
- sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
- mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
- id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message.
- **options (Any): Options passed to each pair during construction.
- Raises:
- ErrorMeta: With :class:`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their
- lengths do not match.
- ErrorMeta: With :class:`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their sets
- of keys do not match.
- ErrorMeta: With :class:`TypeError`, if no pair is able to handle the inputs.
- ErrorMeta: With any expected exception that happens during the construction of a pair.
- Returns:
- (List[Pair]): Originated pairs.
- """
- # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
- # "a" == "a"[0][0]...
- if (
- isinstance(actual, sequence_types)
- and not isinstance(actual, str)
- and isinstance(expected, sequence_types)
- and not isinstance(expected, str)
- ):
- actual_len = len(actual)
- expected_len = len(expected)
- if actual_len != expected_len:
- raise ErrorMeta(
- AssertionError,
- f"The length of the sequences mismatch: {actual_len} != {expected_len}",
- id=id,
- )
- pairs = []
- for idx in range(actual_len):
- pairs.extend(
- originate_pairs(
- actual[idx],
- expected[idx],
- pair_types=pair_types,
- sequence_types=sequence_types,
- mapping_types=mapping_types,
- id=(*id, idx),
- **options,
- )
- )
- return pairs
- elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
- actual_keys = set(actual.keys())
- expected_keys = set(expected.keys())
- if actual_keys != expected_keys:
- missing_keys = expected_keys - actual_keys
- additional_keys = actual_keys - expected_keys
- raise ErrorMeta(
- AssertionError,
- (
- f"The keys of the mappings do not match:\n"
- f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
- f"Additional keys in the actual mapping: {sorted(additional_keys)}"
- ),
- id=id,
- )
- keys: Collection = actual_keys
- # Since the origination aborts after the first failure, we try to be deterministic
- with contextlib.suppress(Exception):
- keys = sorted(keys)
- pairs = []
- for key in keys:
- pairs.extend(
- originate_pairs(
- actual[key],
- expected[key],
- pair_types=pair_types,
- sequence_types=sequence_types,
- mapping_types=mapping_types,
- id=(*id, key),
- **options,
- )
- )
- return pairs
- else:
- for pair_type in pair_types:
- try:
- return [pair_type(actual, expected, id=id, **options)]
- # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
- # inputs. Thus, we try the next pair type.
- except UnsupportedInputs:
- continue
- # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This
- # is only in a separate branch, because the one below would also except it.
- except ErrorMeta:
- raise
- # Raising any other exception during origination is unexpected and will give some extra information about
- # what happened. If applicable, the exception should be expected in the future.
- except Exception as error:
- raise RuntimeError(
- f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n"
- f"{type(actual).__name__}(): {actual}\n\n"
- f"and\n\n"
- f"{type(expected).__name__}(): {expected}\n\n"
- f"resulted in the unexpected exception above. "
- f"If you are a user and see this message during normal operation "
- "please file an issue at https://github.com/pytorch/pytorch/issues. "
- "If you are a developer and working on the comparison functions, "
- "please except the previous error and raise an expressive `ErrorMeta` instead."
- ) from error
- else:
- raise ErrorMeta(
- TypeError,
- f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.",
- id=id,
- )
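- # Usage sketch (illustrative; assumes NumberPair from above is passed as a pair type):
- # nested inputs are flattened into one pair per leaf, with `id` recording the traversal path.
- # >>> pairs = originate_pairs([1, {"a": 2.0}], (1, {"a": 2.0}), pair_types=(NumberPair,))
- # >>> [pair.id for pair in pairs]
- # [(0,), (1, 'a')]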
- def not_close_error_metas(
- actual: Any,
- expected: Any,
- *,
- pair_types: Sequence[Type[Pair]] = (ObjectPair,),
- sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
- mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
- **options: Any,
- ) -> List[ErrorMeta]:
- """Asserts that inputs are equal.
- ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
- :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them.
- Args:
- actual (Any): Actual input.
- expected (Any): Expected input.
- pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried, in order, to construct
- a pair from the inputs. The first pair type that can handle the inputs will be used. Defaults to only using
- :class:`ObjectPair`.
- sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
- mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
- **options (Any): Options passed to each pair during construction.
- """
- # Hide this function from `pytest`'s traceback
- __tracebackhide__ = True
- try:
- pairs = originate_pairs(
- actual,
- expected,
- pair_types=pair_types,
- sequence_types=sequence_types,
- mapping_types=mapping_types,
- **options,
- )
- except ErrorMeta as error_meta:
- # Explicitly raising from None to hide the internal traceback
- raise error_meta.to_error() from None
- error_metas: List[ErrorMeta] = []
- for pair in pairs:
- try:
- pair.compare()
- except ErrorMeta as error_meta:
- error_metas.append(error_meta)
- # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information
- # about what happened. If applicable, the exception should be expected in the future.
- except Exception as error:
- raise RuntimeError(
- f"Comparing\n\n"
- f"{pair}\n\n"
- f"resulted in the unexpected exception above. "
- f"If you are a user and see this message during normal operation "
- "please file an issue at https://github.com/pytorch/pytorch/issues. "
- "If you are a developer and working on the comparison functions, "
- "please except the previous error and raise an expressive `ErrorMeta` instead."
- ) from error
- return error_metas
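- # Usage sketch: unlike assert_close below, not_close_error_metas collects the failures
- # instead of raising, so callers can decide how to surface them.
- # >>> not_close_error_metas(1.0, 1.0, pair_types=(NumberPair,))
- # []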
- def assert_close(
- actual: Any,
- expected: Any,
- *,
- allow_subclasses: bool = True,
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
- equal_nan: bool = False,
- check_device: bool = True,
- check_dtype: bool = True,
- check_layout: bool = True,
- check_stride: bool = False,
- msg: Optional[Union[str, Callable[[str], str]]] = None,
- ):
- r"""Asserts that ``actual`` and ``expected`` are close.
- If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if
- .. math::
- \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert
- Non-finite values (``-inf`` and ``inf``) are considered close if and only if they are equal. ``NaN``'s are
- only considered equal to each other if ``equal_nan`` is ``True``.
- In addition, they are only considered close if they have the same
- - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``),
- - ``dtype`` (if ``check_dtype`` is ``True``),
- - ``layout`` (if ``check_layout`` is ``True``), and
- - stride (if ``check_stride`` is ``True``).
- If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed.
- If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are
- checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR,
- or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively,
- are always checked for equality whereas the values are checked for closeness according to the definition above.
- If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
- :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
- definition above.
- ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which
- :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types
- have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s
- or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all
- their elements are considered close according to the above definition.
- .. note::
- Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e.
- :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus,
- Python scalars of different types can be checked, but require ``check_dtype=False``.
- Args:
- actual (Any): Actual input.
- expected (Any): Expected input.
- allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types
- are allowed. Otherwise type equality is required.
- rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
- values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
- atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
- values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
- equal_nan (bool): If ``True``, two ``NaN`` values will be considered equal. Defaults to ``False``.
- check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
- :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
- :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
- check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
- check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
- :func:`torch.promote_types`) before being compared.
- check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
- check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
- compared.
- check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
- msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
- the comparison. Can also be passed as a callable in which case it will be called with the generated message and
- should return the new message.
- Raises:
- ValueError: If no :class:`torch.Tensor` can be constructed from an input.
- ValueError: If only ``rtol`` or ``atol`` is specified.
- AssertionError: If corresponding inputs are not Python scalars and are not directly related.
- AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have
- different types.
- AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their lengths do not match.
- AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their sets of keys do not match.
- AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
- AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same
- :attr:`~torch.Tensor.layout`.
- AssertionError: If only one of corresponding tensors is quantized.
- AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s.
- AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same
- :attr:`~torch.Tensor.device`.
- AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
- AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
- AssertionError: If the values of corresponding tensors are not close according to the definition above.
- The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
- ``dtype``'s, the maximum of both tolerances is used.
- +---------------------------+------------+----------+
- | ``dtype`` | ``rtol`` | ``atol`` |
- +===========================+============+==========+
- | :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` |
- +---------------------------+------------+----------+
- | other | ``0.0`` | ``0.0`` |
- +---------------------------+------------+----------+
- .. note::
- :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged
- to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might
- define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default:
- >>> import functools
- >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
- >>> assert_equal(1e-9, 1e-10)
- Traceback (most recent call last):
- ...
- AssertionError: Scalars are not equal!
- <BLANKLINE>
- Absolute difference: 9.000000000000001e-10
- Relative difference: 9.0
- Examples:
- >>> # tensor to tensor comparison
- >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
- >>> actual = torch.acos(torch.cos(expected))
- >>> torch.testing.assert_close(actual, expected)
- >>> # scalar to scalar comparison
- >>> import math
- >>> expected = math.sqrt(2.0)
- >>> actual = 2.0 / math.sqrt(2.0)
- >>> torch.testing.assert_close(actual, expected)
- >>> # numpy array to numpy array comparison
- >>> import numpy as np
- >>> expected = np.array([1e0, 1e-1, 1e-2])
- >>> actual = np.arccos(np.cos(expected))
- >>> torch.testing.assert_close(actual, expected)
- >>> # sequence to sequence comparison
- >>> import numpy as np
- >>> # The types of the sequences do not have to match. They only have to have the same
- >>> # length and their elements have to match.
- >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
- >>> actual = tuple(expected)
- >>> torch.testing.assert_close(actual, expected)
- >>> # mapping to mapping comparison
- >>> from collections import OrderedDict
- >>> import numpy as np
- >>> foo = torch.tensor(1.0)
- >>> bar = 2.0
- >>> baz = np.array(3.0)
- >>> # The types and a possible ordering of mappings do not have to match. They only
- >>> # have to have the same set of keys and their elements have to match.
- >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
- >>> actual = {"baz": baz, "bar": bar, "foo": foo}
- >>> torch.testing.assert_close(actual, expected)
- >>> expected = torch.tensor([1.0, 2.0, 3.0])
- >>> actual = expected.clone()
- >>> # By default, directly related instances can be compared
- >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
- >>> # This check can be made more strict with allow_subclasses=False
- >>> torch.testing.assert_close(
- ... torch.nn.Parameter(actual), expected, allow_subclasses=False
- ... )
- Traceback (most recent call last):
- ...
- TypeError: No comparison pair was able to handle inputs of type
- <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
- >>> # If the inputs are not directly related, they are never considered close
- >>> torch.testing.assert_close(actual.numpy(), expected)
- Traceback (most recent call last):
- ...
- TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
- and <class 'torch.Tensor'>.
- >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
- >>> # their type if check_dtype=False.
- >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
- >>> # NaN != NaN by default.
- >>> expected = torch.tensor(float("Nan"))
- >>> actual = expected.clone()
- >>> torch.testing.assert_close(actual, expected)
- Traceback (most recent call last):
- ...
- AssertionError: Scalars are not close!
- <BLANKLINE>
- Absolute difference: nan (up to 1e-05 allowed)
- Relative difference: nan (up to 1.3e-06 allowed)
- >>> torch.testing.assert_close(actual, expected, equal_nan=True)
- >>> expected = torch.tensor([1.0, 2.0, 3.0])
- >>> actual = torch.tensor([1.0, 4.0, 5.0])
- >>> # The default error message can be overwritten.
- >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
- Traceback (most recent call last):
- ...
- AssertionError: Argh, the tensors are not close!
- >>> # If msg is a callable, it can be used to augment the generated message with
- >>> # extra information
- >>> torch.testing.assert_close(
- ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
- ... )
- Traceback (most recent call last):
- ...
- AssertionError: Header
- <BLANKLINE>
- Tensor-likes are not close!
- <BLANKLINE>
- Mismatched elements: 2 / 3 (66.7%)
- Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
- Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
- <BLANKLINE>
- Footer
- """
- # Hide this function from `pytest`'s traceback
- __tracebackhide__ = True
- error_metas = not_close_error_metas(
- actual,
- expected,
- pair_types=(
- NonePair,
- BooleanPair,
- NumberPair,
- TensorLikePair,
- ),
- allow_subclasses=allow_subclasses,
- rtol=rtol,
- atol=atol,
- equal_nan=equal_nan,
- check_device=check_device,
- check_dtype=check_dtype,
- check_layout=check_layout,
- check_stride=check_stride,
- msg=msg,
- )
- if error_metas:
- # TODO: compose all metas into one AssertionError
- raise error_metas[0].to_error(msg)
- def assert_allclose(
- actual: Any,
- expected: Any,
- rtol: Optional[float] = None,
- atol: Optional[float] = None,
- equal_nan: bool = True,
- msg: str = "",
- ) -> None:
- """
- .. warning::
- :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release.
- Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions
- `here <https://github.com/pytorch/pytorch/issues/61844>`_.
- """
- warnings.warn(
- "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. "
- "Please use `torch.testing.assert_close()` instead. "
- "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.",
- FutureWarning,
- stacklevel=2,
- )
- if not isinstance(actual, torch.Tensor):
- actual = torch.tensor(actual)
- if not isinstance(expected, torch.Tensor):
- expected = torch.tensor(expected, dtype=actual.dtype)
- if rtol is None and atol is None:
- rtol, atol = default_tolerances(
- actual,
- expected,
- dtype_precisions={
- torch.float16: (1e-3, 1e-3),
- torch.float32: (1e-4, 1e-5),
- torch.float64: (1e-5, 1e-8),
- },
- )
- torch.testing.assert_close(
- actual,
- expected,
- rtol=rtol,
- atol=atol,
- equal_nan=equal_nan,
- check_device=True,
- check_dtype=False,
- check_stride=False,
- msg=msg or None,
- )