_comparison.py 60 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563
  1. import abc
  2. import cmath
  3. import collections.abc
  4. import contextlib
  5. import warnings
  6. from typing import (
  7. Any,
  8. Callable,
  9. Collection,
  10. Dict,
  11. List,
  12. NoReturn,
  13. Optional,
  14. Sequence,
  15. Tuple,
  16. Type,
  17. Union,
  18. )
  19. import torch
  20. try:
  21. import numpy as np
  22. NUMPY_AVAILABLE = True
  23. except ModuleNotFoundError:
  24. NUMPY_AVAILABLE = False
  25. class ErrorMeta(Exception):
  26. """Internal testing exception that makes that carries error metadata."""
  27. def __init__(
  28. self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
  29. ) -> None:
  30. super().__init__(
  31. "If you are a user and see this message during normal operation "
  32. "please file an issue at https://github.com/pytorch/pytorch/issues. "
  33. "If you are a developer and working on the comparison functions, please `raise ErrorMeta().to_error()` "
  34. "for user facing errors."
  35. )
  36. self.type = type
  37. self.msg = msg
  38. self.id = id
  39. def to_error(
  40. self, msg: Optional[Union[str, Callable[[str], str]]] = None
  41. ) -> Exception:
  42. if not isinstance(msg, str):
  43. generated_msg = self.msg
  44. if self.id:
  45. generated_msg += f"\n\nThe failure occurred for item {''.join(str([item]) for item in self.id)}"
  46. msg = msg(generated_msg) if callable(msg) else generated_msg
  47. return self.type(msg)
  48. # Some analysis of tolerance by logging tests from test_torch.py can be found in
  49. # https://github.com/pytorch/pytorch/pull/32538.
  50. # {dtype: (rtol, atol)}
  51. _DTYPE_PRECISIONS = {
  52. torch.float16: (0.001, 1e-5),
  53. torch.bfloat16: (0.016, 1e-5),
  54. torch.float32: (1.3e-6, 1e-5),
  55. torch.float64: (1e-7, 1e-7),
  56. torch.complex32: (0.001, 1e-5),
  57. torch.complex64: (1.3e-6, 1e-5),
  58. torch.complex128: (1e-7, 1e-7),
  59. }
  60. # The default tolerances of torch.float32 are used for quantized dtypes, because quantized tensors are compared in
  61. # their dequantized and floating point representation. For more details see `TensorLikePair._compare_quantized_values`
  62. _DTYPE_PRECISIONS.update(
  63. {
  64. dtype: _DTYPE_PRECISIONS[torch.float32]
  65. for dtype in (
  66. torch.quint8,
  67. torch.quint2x4,
  68. torch.quint4x2,
  69. torch.qint8,
  70. torch.qint32,
  71. )
  72. }
  73. )
  74. def default_tolerances(
  75. *inputs: Union[torch.Tensor, torch.dtype],
  76. dtype_precisions: Optional[Dict[torch.dtype, Tuple[float, float]]] = None,
  77. ) -> Tuple[float, float]:
  78. """Returns the default absolute and relative testing tolerances for a set of inputs based on the dtype.
  79. See :func:`assert_close` for a table of the default tolerance for each dtype.
  80. Returns:
  81. (Tuple[float, float]): Loosest tolerances of all input dtypes.
  82. """
  83. dtypes = []
  84. for input in inputs:
  85. if isinstance(input, torch.Tensor):
  86. dtypes.append(input.dtype)
  87. elif isinstance(input, torch.dtype):
  88. dtypes.append(input)
  89. else:
  90. raise TypeError(
  91. f"Expected a torch.Tensor or a torch.dtype, but got {type(input)} instead."
  92. )
  93. dtype_precisions = dtype_precisions or _DTYPE_PRECISIONS
  94. rtols, atols = zip(*[dtype_precisions.get(dtype, (0.0, 0.0)) for dtype in dtypes])
  95. return max(rtols), max(atols)
  96. def get_tolerances(
  97. *inputs: Union[torch.Tensor, torch.dtype],
  98. rtol: Optional[float],
  99. atol: Optional[float],
  100. id: Tuple[Any, ...] = (),
  101. ) -> Tuple[float, float]:
  102. """Gets absolute and relative to be used for numeric comparisons.
  103. If both ``rtol`` and ``atol`` are specified, this is a no-op. If both are not specified, the return value of
  104. :func:`default_tolerances` is used.
  105. Raises:
  106. ErrorMeta: With :class:`ValueError`, if only ``rtol`` or ``atol`` is specified.
  107. Returns:
  108. (Tuple[float, float]): Valid absolute and relative tolerances.
  109. """
  110. if (rtol is None) ^ (atol is None):
  111. # We require both tolerance to be omitted or specified, because specifying only one might lead to surprising
  112. # results. Imagine setting atol=0.0 and the tensors still match because rtol>0.0.
  113. raise ErrorMeta(
  114. ValueError,
  115. f"Both 'rtol' and 'atol' must be either specified or omitted, "
  116. f"but got no {'rtol' if rtol is None else 'atol'}.",
  117. id=id,
  118. )
  119. elif rtol is not None and atol is not None:
  120. return rtol, atol
  121. else:
  122. return default_tolerances(*inputs)
  123. def _make_mismatch_msg(
  124. *,
  125. default_identifier: str,
  126. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  127. extra: Optional[str] = None,
  128. abs_diff: float,
  129. abs_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
  130. atol: float,
  131. rel_diff: float,
  132. rel_diff_idx: Optional[Union[int, Tuple[int, ...]]] = None,
  133. rtol: float,
  134. ) -> str:
  135. """Makes a mismatch error message for numeric values.
  136. Args:
  137. default_identifier (str): Default description of the compared values, e.g. "Tensor-likes".
  138. identifier (Optional[Union[str, Callable[[str], str]]]): Optional identifier that overrides
  139. ``default_identifier``. Can be passed as callable in which case it will be called with
  140. ``default_identifier`` to create the description at runtime.
  141. extra (Optional[str]): Extra information to be placed after the message header and the mismatch statistics.
  142. abs_diff (float): Absolute difference.
  143. abs_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the absolute difference.
  144. atol (float): Allowed absolute tolerance. Will only be added to mismatch statistics if it or ``rtol`` are
  145. ``> 0``.
  146. rel_diff (float): Relative difference.
  147. rel_diff_idx (Optional[Union[int, Tuple[int, ...]]]): Optional index of the relative difference.
  148. rtol (float): Allowed relative tolerance. Will only be added to mismatch statistics if it or ``atol`` are
  149. ``> 0``.
  150. """
  151. equality = rtol == 0 and atol == 0
  152. def make_diff_msg(
  153. *,
  154. type: str,
  155. diff: float,
  156. idx: Optional[Union[int, Tuple[int, ...]]],
  157. tol: float,
  158. ) -> str:
  159. if idx is None:
  160. msg = f"{type.title()} difference: {diff}"
  161. else:
  162. msg = f"Greatest {type} difference: {diff} at index {idx}"
  163. if not equality:
  164. msg += f" (up to {tol} allowed)"
  165. return msg + "\n"
  166. if identifier is None:
  167. identifier = default_identifier
  168. elif callable(identifier):
  169. identifier = identifier(default_identifier)
  170. msg = f"{identifier} are not {'equal' if equality else 'close'}!\n\n"
  171. if extra:
  172. msg += f"{extra.strip()}\n"
  173. msg += make_diff_msg(type="absolute", diff=abs_diff, idx=abs_diff_idx, tol=atol)
  174. msg += make_diff_msg(type="relative", diff=rel_diff, idx=rel_diff_idx, tol=rtol)
  175. return msg.strip()
  176. def make_scalar_mismatch_msg(
  177. actual: Union[int, float, complex],
  178. expected: Union[int, float, complex],
  179. *,
  180. rtol: float,
  181. atol: float,
  182. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  183. ) -> str:
  184. """Makes a mismatch error message for scalars.
  185. Args:
  186. actual (Union[int, float, complex]): Actual scalar.
  187. expected (Union[int, float, complex]): Expected scalar.
  188. rtol (float): Relative tolerance.
  189. atol (float): Absolute tolerance.
  190. identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the scalars. Can be passed
  191. as callable in which case it will be called by the default value to create the description at runtime.
  192. Defaults to "Scalars".
  193. """
  194. abs_diff = abs(actual - expected)
  195. rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
  196. return _make_mismatch_msg(
  197. default_identifier="Scalars",
  198. identifier=identifier,
  199. abs_diff=abs_diff,
  200. atol=atol,
  201. rel_diff=rel_diff,
  202. rtol=rtol,
  203. )
  204. def make_tensor_mismatch_msg(
  205. actual: torch.Tensor,
  206. expected: torch.Tensor,
  207. mismatches: torch.Tensor,
  208. *,
  209. rtol: float,
  210. atol: float,
  211. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  212. ):
  213. """Makes a mismatch error message for tensors.
  214. Args:
  215. actual (torch.Tensor): Actual tensor.
  216. expected (torch.Tensor): Expected tensor.
  217. mismatches (torch.Tensor): Boolean mask of the same shape as ``actual`` and ``expected`` that indicates the
  218. location of mismatches.
  219. rtol (float): Relative tolerance.
  220. atol (float): Absolute tolerance.
  221. identifier (Optional[Union[str, Callable[[str], str]]]): Optional description for the tensors. Can be passed
  222. as callable in which case it will be called by the default value to create the description at runtime.
  223. Defaults to "Tensor-likes".
  224. """
  225. def unravel_flat_index(flat_index: int) -> Tuple[int, ...]:
  226. if not mismatches.shape:
  227. return ()
  228. inverse_index = []
  229. for size in mismatches.shape[::-1]:
  230. div, mod = divmod(flat_index, size)
  231. flat_index = div
  232. inverse_index.append(mod)
  233. return tuple(inverse_index[::-1])
  234. number_of_elements = mismatches.numel()
  235. total_mismatches = torch.sum(mismatches).item()
  236. extra = (
  237. f"Mismatched elements: {total_mismatches} / {number_of_elements} "
  238. f"({total_mismatches / number_of_elements:.1%})"
  239. )
  240. a_flat = actual.flatten()
  241. b_flat = expected.flatten()
  242. matches_flat = ~mismatches.flatten()
  243. abs_diff = torch.abs(a_flat - b_flat)
  244. # Ensure that only mismatches are used for the max_abs_diff computation
  245. abs_diff[matches_flat] = 0
  246. max_abs_diff, max_abs_diff_flat_idx = torch.max(abs_diff, 0)
  247. rel_diff = abs_diff / torch.abs(b_flat)
  248. # Ensure that only mismatches are used for the max_rel_diff computation
  249. rel_diff[matches_flat] = 0
  250. max_rel_diff, max_rel_diff_flat_idx = torch.max(rel_diff, 0)
  251. return _make_mismatch_msg(
  252. default_identifier="Tensor-likes",
  253. identifier=identifier,
  254. extra=extra,
  255. abs_diff=max_abs_diff.item(),
  256. abs_diff_idx=unravel_flat_index(int(max_abs_diff_flat_idx)),
  257. atol=atol,
  258. rel_diff=max_rel_diff.item(),
  259. rel_diff_idx=unravel_flat_index(int(max_rel_diff_flat_idx)),
  260. rtol=rtol,
  261. )
  262. class UnsupportedInputs(Exception): # noqa: B903
  263. """Exception to be raised during the construction of a :class:`Pair` in case it doesn't support the inputs."""
  264. class Pair(abc.ABC):
  265. """ABC for all comparison pairs to be used in conjunction with :func:`assert_equal`.
  266. Each subclass needs to overwrite :meth:`Pair.compare` that performs the actual comparison.
  267. Each pair receives **all** options, so select the ones applicable for the subclass and forward the rest to the
  268. super class. Raising an :class:`UnsupportedInputs` during constructions indicates that the pair is not able to
  269. handle the inputs and the next pair type will be tried.
  270. All other errors should be raised as :class:`ErrorMeta`. After the instantiation, :meth:`Pair._make_error_meta` can
  271. be used to automatically handle overwriting the message with a user supplied one and id handling.
  272. """
  273. def __init__(
  274. self,
  275. actual: Any,
  276. expected: Any,
  277. *,
  278. id: Tuple[Any, ...] = (),
  279. **unknown_parameters: Any,
  280. ) -> None:
  281. self.actual = actual
  282. self.expected = expected
  283. self.id = id
  284. self._unknown_parameters = unknown_parameters
  285. @staticmethod
  286. def _inputs_not_supported() -> NoReturn:
  287. raise UnsupportedInputs()
  288. @staticmethod
  289. def _check_inputs_isinstance(*inputs: Any, cls: Union[Type, Tuple[Type, ...]]):
  290. """Checks if all inputs are instances of a given class and raise :class:`UnsupportedInputs` otherwise."""
  291. if not all(isinstance(input, cls) for input in inputs):
  292. Pair._inputs_not_supported()
  293. def _fail(
  294. self, type: Type[Exception], msg: str, *, id: Tuple[Any, ...] = ()
  295. ) -> NoReturn:
  296. """Raises an :class:`ErrorMeta` from a given exception type and message and the stored id.
  297. .. warning::
  298. If you use this before the ``super().__init__(...)`` call in the constructor, you have to pass the ``id``
  299. explicitly.
  300. """
  301. raise ErrorMeta(type, msg, id=self.id if not id and hasattr(self, "id") else id)
  302. @abc.abstractmethod
  303. def compare(self) -> None:
  304. """Compares the inputs and raises an :class`ErrorMeta` in case they mismatch."""
  305. def extra_repr(self) -> Sequence[Union[str, Tuple[str, Any]]]:
  306. """Returns extra information that will be included in the representation.
  307. Should be overwritten by all subclasses that use additional options. The representation of the object will only
  308. be surfaced in case we encounter an unexpected error and thus should help debug the issue. Can be a sequence of
  309. key-value-pairs or attribute names.
  310. """
  311. return []
  312. def __repr__(self) -> str:
  313. head = f"{type(self).__name__}("
  314. tail = ")"
  315. body = [
  316. f" {name}={value!s},"
  317. for name, value in [
  318. ("id", self.id),
  319. ("actual", self.actual),
  320. ("expected", self.expected),
  321. *[
  322. (extra, getattr(self, extra)) if isinstance(extra, str) else extra
  323. for extra in self.extra_repr()
  324. ],
  325. ]
  326. ]
  327. return "\n".join((head, *body, *tail))
  328. class ObjectPair(Pair):
  329. """Pair for any type of inputs that will be compared with the `==` operator.
  330. .. note::
  331. Since this will instantiate for any kind of inputs, it should only be used as fallback after all other pairs
  332. couldn't handle the inputs.
  333. """
  334. def compare(self) -> None:
  335. try:
  336. equal = self.actual == self.expected
  337. except Exception as error:
  338. # We are not using `self._raise_error_meta` here since we need the exception chaining
  339. raise ErrorMeta(
  340. ValueError,
  341. f"{self.actual} == {self.expected} failed with:\n{error}.",
  342. id=self.id,
  343. ) from error
  344. if not equal:
  345. self._fail(AssertionError, f"{self.actual} != {self.expected}")
  346. class NonePair(Pair):
  347. """Pair for ``None`` inputs."""
  348. def __init__(self, actual: Any, expected: Any, **other_parameters: Any) -> None:
  349. if not (actual is None or expected is None):
  350. self._inputs_not_supported()
  351. super().__init__(actual, expected, **other_parameters)
  352. def compare(self) -> None:
  353. if not (self.actual is None and self.expected is None):
  354. self._fail(
  355. AssertionError, f"None mismatch: {self.actual} is not {self.expected}"
  356. )
  357. class BooleanPair(Pair):
  358. """Pair for :class:`bool` inputs.
  359. .. note::
  360. If ``numpy`` is available, also handles :class:`numpy.bool_` inputs.
  361. """
  362. def __init__(
  363. self,
  364. actual: Any,
  365. expected: Any,
  366. *,
  367. id: Tuple[Any, ...],
  368. **other_parameters: Any,
  369. ) -> None:
  370. actual, expected = self._process_inputs(actual, expected, id=id)
  371. super().__init__(actual, expected, **other_parameters)
  372. @property
  373. def _supported_types(self) -> Tuple[Type, ...]:
  374. cls: List[Type] = [bool]
  375. if NUMPY_AVAILABLE:
  376. cls.append(np.bool_)
  377. return tuple(cls)
  378. def _process_inputs(
  379. self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
  380. ) -> Tuple[bool, bool]:
  381. self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
  382. actual, expected = [
  383. self._to_bool(bool_like, id=id) for bool_like in (actual, expected)
  384. ]
  385. return actual, expected
  386. def _to_bool(self, bool_like: Any, *, id: Tuple[Any, ...]) -> bool:
  387. if isinstance(bool_like, bool):
  388. return bool_like
  389. elif isinstance(bool_like, np.bool_):
  390. return bool_like.item()
  391. else:
  392. raise ErrorMeta(
  393. TypeError, f"Unknown boolean type {type(bool_like)}.", id=id
  394. )
  395. def compare(self) -> None:
  396. if self.actual is not self.expected:
  397. self._fail(
  398. AssertionError,
  399. f"Booleans mismatch: {self.actual} is not {self.expected}",
  400. )
  401. class NumberPair(Pair):
  402. """Pair for Python number (:class:`int`, :class:`float`, and :class:`complex`) inputs.
  403. .. note::
  404. If ``numpy`` is available, also handles :class:`numpy.number` inputs.
  405. Kwargs:
  406. rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
  407. values based on the type are selected with the below table.
  408. atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
  409. values based on the type are selected with the below table.
  410. equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
  411. check_dtype (bool): If ``True``, the type of the inputs will be checked for equality. Defaults to ``False``.
  412. The following table displays correspondence between Python number type and the ``torch.dtype``'s. See
  413. :func:`assert_close` for the corresponding tolerances.
  414. +------------------+-------------------------------+
  415. | ``type`` | corresponding ``torch.dtype`` |
  416. +==================+===============================+
  417. | :class:`int` | :attr:`~torch.int64` |
  418. +------------------+-------------------------------+
  419. | :class:`float` | :attr:`~torch.float64` |
  420. +------------------+-------------------------------+
  421. | :class:`complex` | :attr:`~torch.complex64` |
  422. +------------------+-------------------------------+
  423. """
  424. _TYPE_TO_DTYPE = {
  425. int: torch.int64,
  426. float: torch.float64,
  427. complex: torch.complex128,
  428. }
  429. _NUMBER_TYPES = tuple(_TYPE_TO_DTYPE.keys())
  430. def __init__(
  431. self,
  432. actual: Any,
  433. expected: Any,
  434. *,
  435. id: Tuple[Any, ...] = (),
  436. rtol: Optional[float] = None,
  437. atol: Optional[float] = None,
  438. equal_nan: bool = False,
  439. check_dtype: bool = False,
  440. **other_parameters: Any,
  441. ) -> None:
  442. actual, expected = self._process_inputs(actual, expected, id=id)
  443. super().__init__(actual, expected, id=id, **other_parameters)
  444. self.rtol, self.atol = get_tolerances(
  445. *[self._TYPE_TO_DTYPE[type(input)] for input in (actual, expected)],
  446. rtol=rtol,
  447. atol=atol,
  448. id=id,
  449. )
  450. self.equal_nan = equal_nan
  451. self.check_dtype = check_dtype
  452. @property
  453. def _supported_types(self) -> Tuple[Type, ...]:
  454. cls = list(self._NUMBER_TYPES)
  455. if NUMPY_AVAILABLE:
  456. cls.append(np.number)
  457. return tuple(cls)
  458. def _process_inputs(
  459. self, actual: Any, expected: Any, *, id: Tuple[Any, ...]
  460. ) -> Tuple[Union[int, float, complex], Union[int, float, complex]]:
  461. self._check_inputs_isinstance(actual, expected, cls=self._supported_types)
  462. actual, expected = [
  463. self._to_number(number_like, id=id) for number_like in (actual, expected)
  464. ]
  465. return actual, expected
  466. def _to_number(
  467. self, number_like: Any, *, id: Tuple[Any, ...]
  468. ) -> Union[int, float, complex]:
  469. if NUMPY_AVAILABLE and isinstance(number_like, np.number):
  470. return number_like.item()
  471. elif isinstance(number_like, self._NUMBER_TYPES):
  472. return number_like
  473. else:
  474. raise ErrorMeta(
  475. TypeError, f"Unknown number type {type(number_like)}.", id=id
  476. )
  477. def compare(self) -> None:
  478. if self.check_dtype and type(self.actual) is not type(self.expected):
  479. self._fail(
  480. AssertionError,
  481. f"The (d)types do not match: {type(self.actual)} != {type(self.expected)}.",
  482. )
  483. if self.actual == self.expected:
  484. return
  485. if self.equal_nan and cmath.isnan(self.actual) and cmath.isnan(self.expected):
  486. return
  487. abs_diff = abs(self.actual - self.expected)
  488. tolerance = self.atol + self.rtol * abs(self.expected)
  489. if cmath.isfinite(abs_diff) and abs_diff <= tolerance:
  490. return
  491. self._fail(
  492. AssertionError,
  493. make_scalar_mismatch_msg(
  494. self.actual, self.expected, rtol=self.rtol, atol=self.atol
  495. ),
  496. )
  497. def extra_repr(self) -> Sequence[str]:
  498. return (
  499. "rtol",
  500. "atol",
  501. "equal_nan",
  502. "check_dtype",
  503. )
  504. class TensorLikePair(Pair):
  505. """Pair for :class:`torch.Tensor`-like inputs.
  506. Kwargs:
  507. allow_subclasses (bool):
  508. rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
  509. values based on the type are selected. See :func:assert_close: for details.
  510. atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
  511. values based on the type are selected. See :func:assert_close: for details.
  512. equal_nan (bool): If ``True``, two ``NaN`` values are considered equal. Defaults to ``False``.
  513. check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
  514. :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
  515. :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
  516. check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
  517. check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
  518. :func:`torch.promote_types`) before being compared.
  519. check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
  520. check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
  521. compared.
  522. check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
  523. """
  524. def __init__(
  525. self,
  526. actual: Any,
  527. expected: Any,
  528. *,
  529. id: Tuple[Any, ...] = (),
  530. allow_subclasses: bool = True,
  531. rtol: Optional[float] = None,
  532. atol: Optional[float] = None,
  533. equal_nan: bool = False,
  534. check_device: bool = True,
  535. check_dtype: bool = True,
  536. check_layout: bool = True,
  537. check_stride: bool = False,
  538. **other_parameters: Any,
  539. ):
  540. actual, expected = self._process_inputs(
  541. actual, expected, id=id, allow_subclasses=allow_subclasses
  542. )
  543. super().__init__(actual, expected, id=id, **other_parameters)
  544. self.rtol, self.atol = get_tolerances(
  545. actual, expected, rtol=rtol, atol=atol, id=self.id
  546. )
  547. self.equal_nan = equal_nan
  548. self.check_device = check_device
  549. self.check_dtype = check_dtype
  550. self.check_layout = check_layout
  551. self.check_stride = check_stride
  552. def _process_inputs(
  553. self, actual: Any, expected: Any, *, id: Tuple[Any, ...], allow_subclasses: bool
  554. ) -> Tuple[torch.Tensor, torch.Tensor]:
  555. directly_related = isinstance(actual, type(expected)) or isinstance(
  556. expected, type(actual)
  557. )
  558. if not directly_related:
  559. self._inputs_not_supported()
  560. if not allow_subclasses and type(actual) is not type(expected):
  561. self._inputs_not_supported()
  562. actual, expected = [self._to_tensor(input) for input in (actual, expected)]
  563. for tensor in (actual, expected):
  564. self._check_supported(tensor, id=id)
  565. return actual, expected
  566. def _to_tensor(self, tensor_like: Any) -> torch.Tensor:
  567. if isinstance(tensor_like, torch.Tensor):
  568. return tensor_like
  569. try:
  570. return torch.as_tensor(tensor_like)
  571. except Exception:
  572. self._inputs_not_supported()
  573. def _check_supported(self, tensor: torch.Tensor, *, id: Tuple[Any, ...]) -> None:
  574. if tensor.layout not in {
  575. torch.strided,
  576. torch.sparse_coo,
  577. torch.sparse_csr,
  578. torch.sparse_csc,
  579. torch.sparse_bsr,
  580. torch.sparse_bsc,
  581. }:
  582. raise ErrorMeta(
  583. ValueError, f"Unsupported tensor layout {tensor.layout}", id=id
  584. )
  585. def compare(self) -> None:
  586. actual, expected = self.actual, self.expected
  587. self._compare_attributes(actual, expected)
  588. if any(input.device.type == "meta" for input in (actual, expected)):
  589. return
  590. actual, expected = self._equalize_attributes(actual, expected)
  591. self._compare_values(actual, expected)
  592. def _compare_attributes(
  593. self,
  594. actual: torch.Tensor,
  595. expected: torch.Tensor,
  596. ) -> None:
  597. """Checks if the attributes of two tensors match.
  598. Always checks
  599. - the :attr:`~torch.Tensor.shape`,
  600. - whether both inputs are quantized or not,
  601. - and if they use the same quantization scheme.
  602. Checks for
  603. - :attr:`~torch.Tensor.layout`,
  604. - :meth:`~torch.Tensor.stride`,
  605. - :attr:`~torch.Tensor.device`, and
  606. - :attr:`~torch.Tensor.dtype`
  607. are optional and can be disabled through the corresponding ``check_*`` flag during construction of the pair.
  608. """
  609. def raise_mismatch_error(
  610. attribute_name: str, actual_value: Any, expected_value: Any
  611. ) -> NoReturn:
  612. self._fail(
  613. AssertionError,
  614. f"The values for attribute '{attribute_name}' do not match: {actual_value} != {expected_value}.",
  615. )
  616. if actual.shape != expected.shape:
  617. raise_mismatch_error("shape", actual.shape, expected.shape)
  618. if actual.is_quantized != expected.is_quantized:
  619. raise_mismatch_error(
  620. "is_quantized", actual.is_quantized, expected.is_quantized
  621. )
  622. elif actual.is_quantized and actual.qscheme() != expected.qscheme():
  623. raise_mismatch_error("qscheme()", actual.qscheme(), expected.qscheme())
  624. if actual.layout != expected.layout:
  625. if self.check_layout:
  626. raise_mismatch_error("layout", actual.layout, expected.layout)
  627. elif (
  628. actual.layout == torch.strided
  629. and self.check_stride
  630. and actual.stride() != expected.stride()
  631. ):
  632. raise_mismatch_error("stride()", actual.stride(), expected.stride())
  633. if self.check_device and actual.device != expected.device:
  634. raise_mismatch_error("device", actual.device, expected.device)
  635. if self.check_dtype and actual.dtype != expected.dtype:
  636. raise_mismatch_error("dtype", actual.dtype, expected.dtype)
  637. def _equalize_attributes(
  638. self, actual: torch.Tensor, expected: torch.Tensor
  639. ) -> Tuple[torch.Tensor, torch.Tensor]:
  640. """Equalizes some attributes of two tensors for value comparison.
  641. If ``actual`` and ``expected`` are ...
  642. - ... not on the same :attr:`~torch.Tensor.device`, they are moved CPU memory.
  643. - ... not of the same ``dtype``, they are promoted to a common ``dtype`` (according to
  644. :func:`torch.promote_types`).
  645. - ... not of the same ``layout``, they are converted to strided tensors.
  646. Args:
  647. actual (Tensor): Actual tensor.
  648. expected (Tensor): Expected tensor.
  649. Returns:
  650. (Tuple[Tensor, Tensor]): Equalized tensors.
  651. """
  652. # The comparison logic uses operators currently not supported by the MPS backends.
  653. # See https://github.com/pytorch/pytorch/issues/77144 for details.
  654. # TODO: Remove this conversion as soon as all operations are supported natively by the MPS backend
  655. if actual.is_mps or expected.is_mps: # type: ignore[attr-defined]
  656. actual = actual.cpu()
  657. expected = expected.cpu()
  658. if actual.device != expected.device:
  659. actual = actual.cpu()
  660. expected = expected.cpu()
  661. if actual.dtype != expected.dtype:
  662. dtype = torch.promote_types(actual.dtype, expected.dtype)
  663. actual = actual.to(dtype)
  664. expected = expected.to(dtype)
  665. if actual.layout != expected.layout:
  666. # These checks are needed, since Tensor.to_dense() fails on tensors that are already strided
  667. actual = actual.to_dense() if actual.layout != torch.strided else actual
  668. expected = (
  669. expected.to_dense() if expected.layout != torch.strided else expected
  670. )
  671. return actual, expected
  672. def _compare_values(self, actual: torch.Tensor, expected: torch.Tensor) -> None:
  673. if actual.is_quantized:
  674. compare_fn = self._compare_quantized_values
  675. elif actual.is_sparse:
  676. compare_fn = self._compare_sparse_coo_values
  677. elif actual.layout in {
  678. torch.sparse_csr,
  679. torch.sparse_csc,
  680. torch.sparse_bsr,
  681. torch.sparse_bsc,
  682. }:
  683. compare_fn = self._compare_sparse_compressed_values
  684. else:
  685. compare_fn = self._compare_regular_values_close
  686. compare_fn(
  687. actual, expected, rtol=self.rtol, atol=self.atol, equal_nan=self.equal_nan
  688. )
  689. def _compare_quantized_values(
  690. self,
  691. actual: torch.Tensor,
  692. expected: torch.Tensor,
  693. *,
  694. rtol: float,
  695. atol: float,
  696. equal_nan: bool,
  697. ) -> None:
  698. """Compares quantized tensors by comparing the :meth:`~torch.Tensor.dequantize`'d variants for closeness.
  699. .. note::
  700. A detailed discussion about why only the dequantized variant is checked for closeness rather than checking
  701. the individual quantization parameters for closeness and the integer representation for equality can be
  702. found in https://github.com/pytorch/pytorch/issues/68548.
  703. """
  704. return self._compare_regular_values_close(
  705. actual.dequantize(),
  706. expected.dequantize(),
  707. rtol=rtol,
  708. atol=atol,
  709. equal_nan=equal_nan,
  710. identifier=lambda default_identifier: f"Quantized {default_identifier.lower()}",
  711. )
  712. def _compare_sparse_coo_values(
  713. self,
  714. actual: torch.Tensor,
  715. expected: torch.Tensor,
  716. *,
  717. rtol: float,
  718. atol: float,
  719. equal_nan: bool,
  720. ) -> None:
  721. """Compares sparse COO tensors by comparing
  722. - the number of sparse dimensions,
  723. - the number of non-zero elements (nnz) for equality,
  724. - the indices for equality, and
  725. - the values for closeness.
  726. """
  727. if actual.sparse_dim() != expected.sparse_dim():
  728. self._fail(
  729. AssertionError,
  730. (
  731. f"The number of sparse dimensions in sparse COO tensors does not match: "
  732. f"{actual.sparse_dim()} != {expected.sparse_dim()}"
  733. ),
  734. )
  735. if actual._nnz() != expected._nnz():
  736. self._fail(
  737. AssertionError,
  738. (
  739. f"The number of specified values in sparse COO tensors does not match: "
  740. f"{actual._nnz()} != {expected._nnz()}"
  741. ),
  742. )
  743. self._compare_regular_values_equal(
  744. actual._indices(),
  745. expected._indices(),
  746. identifier="Sparse COO indices",
  747. )
  748. self._compare_regular_values_close(
  749. actual._values(),
  750. expected._values(),
  751. rtol=rtol,
  752. atol=atol,
  753. equal_nan=equal_nan,
  754. identifier="Sparse COO values",
  755. )
  756. def _compare_sparse_compressed_values(
  757. self,
  758. actual: torch.Tensor,
  759. expected: torch.Tensor,
  760. *,
  761. rtol: float,
  762. atol: float,
  763. equal_nan: bool,
  764. ) -> None:
  765. """Compares sparse compressed tensors by comparing
  766. - the number of non-zero elements (nnz) for equality,
  767. - the plain indices for equality,
  768. - the compressed indices for equality, and
  769. - the values for closeness.
  770. """
  771. format_name, compressed_indices_method, plain_indices_method = {
  772. torch.sparse_csr: (
  773. "CSR",
  774. torch.Tensor.crow_indices,
  775. torch.Tensor.col_indices,
  776. ),
  777. torch.sparse_csc: (
  778. "CSC",
  779. torch.Tensor.ccol_indices,
  780. torch.Tensor.row_indices,
  781. ),
  782. torch.sparse_bsr: (
  783. "BSR",
  784. torch.Tensor.crow_indices,
  785. torch.Tensor.col_indices,
  786. ),
  787. torch.sparse_bsc: (
  788. "BSC",
  789. torch.Tensor.ccol_indices,
  790. torch.Tensor.row_indices,
  791. ),
  792. }[actual.layout]
  793. if actual._nnz() != expected._nnz():
  794. self._fail(
  795. AssertionError,
  796. (
  797. f"The number of specified values in sparse {format_name} tensors does not match: "
  798. f"{actual._nnz()} != {expected._nnz()}"
  799. ),
  800. )
  801. self._compare_regular_values_equal(
  802. compressed_indices_method(actual),
  803. compressed_indices_method(expected),
  804. identifier=f"Sparse {format_name} {compressed_indices_method.__name__}",
  805. )
  806. self._compare_regular_values_equal(
  807. plain_indices_method(actual),
  808. plain_indices_method(expected),
  809. identifier=f"Sparse {format_name} {plain_indices_method.__name__}",
  810. )
  811. self._compare_regular_values_close(
  812. actual.values(),
  813. expected.values(),
  814. rtol=rtol,
  815. atol=atol,
  816. equal_nan=equal_nan,
  817. identifier=f"Sparse {format_name} values",
  818. )
  819. def _compare_regular_values_equal(
  820. self,
  821. actual: torch.Tensor,
  822. expected: torch.Tensor,
  823. *,
  824. equal_nan: bool = False,
  825. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  826. ) -> None:
  827. """Checks if the values of two tensors are equal."""
  828. self._compare_regular_values_close(
  829. actual, expected, rtol=0, atol=0, equal_nan=equal_nan, identifier=identifier
  830. )
  831. def _compare_regular_values_close(
  832. self,
  833. actual: torch.Tensor,
  834. expected: torch.Tensor,
  835. *,
  836. rtol: float,
  837. atol: float,
  838. equal_nan: bool,
  839. identifier: Optional[Union[str, Callable[[str], str]]] = None,
  840. ) -> None:
  841. """Checks if the values of two tensors are close up to a desired tolerance."""
  842. actual, expected = self._promote_for_comparison(actual, expected)
  843. matches = torch.isclose(
  844. actual, expected, rtol=rtol, atol=atol, equal_nan=equal_nan
  845. )
  846. if torch.all(matches):
  847. return
  848. if actual.shape == torch.Size([]):
  849. msg = make_scalar_mismatch_msg(
  850. actual.item(),
  851. expected.item(),
  852. rtol=rtol,
  853. atol=atol,
  854. identifier=identifier,
  855. )
  856. else:
  857. msg = make_tensor_mismatch_msg(
  858. actual, expected, ~matches, rtol=rtol, atol=atol, identifier=identifier
  859. )
  860. self._fail(AssertionError, msg)
  861. def _promote_for_comparison(
  862. self, actual: torch.Tensor, expected: torch.Tensor
  863. ) -> Tuple[torch.Tensor, torch.Tensor]:
  864. """Promotes the inputs to the comparison dtype based on the input dtype.
  865. Returns:
  866. Inputs promoted to the highest precision dtype of the same dtype category. :class:`torch.bool` is treated
  867. as integral dtype.
  868. """
  869. # This is called after self._equalize_attributes() and thus `actual` and `expected` already have the same dtype.
  870. if actual.dtype.is_complex:
  871. dtype = torch.complex128
  872. elif actual.dtype.is_floating_point:
  873. dtype = torch.float64
  874. else:
  875. dtype = torch.int64
  876. return actual.to(dtype), expected.to(dtype)
  877. def extra_repr(self) -> Sequence[str]:
  878. return (
  879. "rtol",
  880. "atol",
  881. "equal_nan",
  882. "check_device",
  883. "check_dtype",
  884. "check_layout",
  885. "check_stride",
  886. )
  887. def originate_pairs(
  888. actual: Any,
  889. expected: Any,
  890. *,
  891. pair_types: Sequence[Type[Pair]],
  892. sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
  893. mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
  894. id: Tuple[Any, ...] = (),
  895. **options: Any,
  896. ) -> List[Pair]:
  897. """Originates pairs from the individual inputs.
  898. ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
  899. :class:`~collections.abc.Mapping`'s. In this case the pairs are originated by recursing through them.
  900. Args:
  901. actual (Any): Actual input.
  902. expected (Any): Expected input.
  903. pair_types (Sequence[Type[Pair]]): Sequence of pair types that will be tried to construct with the inputs.
  904. First successful pair will be used.
  905. sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
  906. mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
  907. id (Tuple[Any, ...]): Optional id of a pair that will be included in an error message.
  908. **options (Any): Options passed to each pair during construction.
  909. Raises:
  910. ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Sequence`'s, but their
  911. length does not match.
  912. ErrorMeta: With :class`AssertionError`, if the inputs are :class:`~collections.abc.Mapping`'s, but their set of
  913. keys do not match.
  914. ErrorMeta: With :class`TypeError`, if no pair is able to handle the inputs.
  915. ErrorMeta: With any expected exception that happens during the construction of a pair.
  916. Returns:
  917. (List[Pair]): Originated pairs.
  918. """
  919. # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
  920. # "a" == "a"[0][0]...
  921. if (
  922. isinstance(actual, sequence_types)
  923. and not isinstance(actual, str)
  924. and isinstance(expected, sequence_types)
  925. and not isinstance(expected, str)
  926. ):
  927. actual_len = len(actual)
  928. expected_len = len(expected)
  929. if actual_len != expected_len:
  930. raise ErrorMeta(
  931. AssertionError,
  932. f"The length of the sequences mismatch: {actual_len} != {expected_len}",
  933. id=id,
  934. )
  935. pairs = []
  936. for idx in range(actual_len):
  937. pairs.extend(
  938. originate_pairs(
  939. actual[idx],
  940. expected[idx],
  941. pair_types=pair_types,
  942. sequence_types=sequence_types,
  943. mapping_types=mapping_types,
  944. id=(*id, idx),
  945. **options,
  946. )
  947. )
  948. return pairs
  949. elif isinstance(actual, mapping_types) and isinstance(expected, mapping_types):
  950. actual_keys = set(actual.keys())
  951. expected_keys = set(expected.keys())
  952. if actual_keys != expected_keys:
  953. missing_keys = expected_keys - actual_keys
  954. additional_keys = actual_keys - expected_keys
  955. raise ErrorMeta(
  956. AssertionError,
  957. (
  958. f"The keys of the mappings do not match:\n"
  959. f"Missing keys in the actual mapping: {sorted(missing_keys)}\n"
  960. f"Additional keys in the actual mapping: {sorted(additional_keys)}"
  961. ),
  962. id=id,
  963. )
  964. keys: Collection = actual_keys
  965. # Since the origination aborts after the first failure, we try to be deterministic
  966. with contextlib.suppress(Exception):
  967. keys = sorted(keys)
  968. pairs = []
  969. for key in keys:
  970. pairs.extend(
  971. originate_pairs(
  972. actual[key],
  973. expected[key],
  974. pair_types=pair_types,
  975. sequence_types=sequence_types,
  976. mapping_types=mapping_types,
  977. id=(*id, key),
  978. **options,
  979. )
  980. )
  981. return pairs
  982. else:
  983. for pair_type in pair_types:
  984. try:
  985. return [pair_type(actual, expected, id=id, **options)]
  986. # Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
  987. # inputs. Thus, we try the next pair type.
  988. except UnsupportedInputs:
  989. continue
  990. # Raising an `ErrorMeta` during origination is the orderly way to abort and so we simply re-raise it. This
  991. # is only in a separate branch, because the one below would also except it.
  992. except ErrorMeta:
  993. raise
  994. # Raising any other exception during origination is unexpected and will give some extra information about
  995. # what happened. If applicable, the exception should be expected in the future.
  996. except Exception as error:
  997. raise RuntimeError(
  998. f"Originating a {pair_type.__name__}() at item {''.join(str([item]) for item in id)} with\n\n"
  999. f"{type(actual).__name__}(): {actual}\n\n"
  1000. f"and\n\n"
  1001. f"{type(expected).__name__}(): {expected}\n\n"
  1002. f"resulted in the unexpected exception above. "
  1003. f"If you are a user and see this message during normal operation "
  1004. "please file an issue at https://github.com/pytorch/pytorch/issues. "
  1005. "If you are a developer and working on the comparison functions, "
  1006. "please except the previous error and raise an expressive `ErrorMeta` instead."
  1007. ) from error
  1008. else:
  1009. raise ErrorMeta(
  1010. TypeError,
  1011. f"No comparison pair was able to handle inputs of type {type(actual)} and {type(expected)}.",
  1012. id=id,
  1013. )
  1014. def not_close_error_metas(
  1015. actual: Any,
  1016. expected: Any,
  1017. *,
  1018. pair_types: Sequence[Type[Pair]] = (ObjectPair,),
  1019. sequence_types: Tuple[Type, ...] = (collections.abc.Sequence,),
  1020. mapping_types: Tuple[Type, ...] = (collections.abc.Mapping,),
  1021. **options: Any,
  1022. ) -> List[ErrorMeta]:
  1023. """Asserts that inputs are equal.
  1024. ``actual`` and ``expected`` can be possibly nested :class:`~collections.abc.Sequence`'s or
  1025. :class:`~collections.abc.Mapping`'s. In this case the comparison happens elementwise by recursing through them.
  1026. Args:
  1027. actual (Any): Actual input.
  1028. expected (Any): Expected input.
  1029. pair_types (Sequence[Type[Pair]]): Sequence of :class:`Pair` types that will be tried to construct with the
  1030. inputs. First successful pair will be used. Defaults to only using :class:`ObjectPair`.
  1031. sequence_types (Tuple[Type, ...]): Optional types treated as sequences that will be checked elementwise.
  1032. mapping_types (Tuple[Type, ...]): Optional types treated as mappings that will be checked elementwise.
  1033. **options (Any): Options passed to each pair during construction.
  1034. """
  1035. # Hide this function from `pytest`'s traceback
  1036. __tracebackhide__ = True
  1037. try:
  1038. pairs = originate_pairs(
  1039. actual,
  1040. expected,
  1041. pair_types=pair_types,
  1042. sequence_types=sequence_types,
  1043. mapping_types=mapping_types,
  1044. **options,
  1045. )
  1046. except ErrorMeta as error_meta:
  1047. # Explicitly raising from None to hide the internal traceback
  1048. raise error_meta.to_error() from None
  1049. error_metas: List[ErrorMeta] = []
  1050. for pair in pairs:
  1051. try:
  1052. pair.compare()
  1053. except ErrorMeta as error_meta:
  1054. error_metas.append(error_meta)
  1055. # Raising any exception besides `ErrorMeta` while comparing is unexpected and will give some extra information
  1056. # about what happened. If applicable, the exception should be expected in the future.
  1057. except Exception as error:
  1058. raise RuntimeError(
  1059. f"Comparing\n\n"
  1060. f"{pair}\n\n"
  1061. f"resulted in the unexpected exception above. "
  1062. f"If you are a user and see this message during normal operation "
  1063. "please file an issue at https://github.com/pytorch/pytorch/issues. "
  1064. "If you are a developer and working on the comparison functions, "
  1065. "please except the previous error and raise an expressive `ErrorMeta` instead."
  1066. ) from error
  1067. return error_metas
  1068. def assert_close(
  1069. actual: Any,
  1070. expected: Any,
  1071. *,
  1072. allow_subclasses: bool = True,
  1073. rtol: Optional[float] = None,
  1074. atol: Optional[float] = None,
  1075. equal_nan: bool = False,
  1076. check_device: bool = True,
  1077. check_dtype: bool = True,
  1078. check_layout: bool = True,
  1079. check_stride: bool = False,
  1080. msg: Optional[Union[str, Callable[[str], str]]] = None,
  1081. ):
  1082. r"""Asserts that ``actual`` and ``expected`` are close.
  1083. If ``actual`` and ``expected`` are strided, non-quantized, real-valued, and finite, they are considered close if
  1084. .. math::
  1085. \lvert \text{actual} - \text{expected} \rvert \le \texttt{atol} + \texttt{rtol} \cdot \lvert \text{expected} \rvert
  1086. Non-finite values (``-inf`` and ``inf``) are only considered close if and only if they are equal. ``NaN``'s are
  1087. only considered equal to each other if ``equal_nan`` is ``True``.
  1088. In addition, they are only considered close if they have the same
  1089. - :attr:`~torch.Tensor.device` (if ``check_device`` is ``True``),
  1090. - ``dtype`` (if ``check_dtype`` is ``True``),
  1091. - ``layout`` (if ``check_layout`` is ``True``), and
  1092. - stride (if ``check_stride`` is ``True``).
  1093. If either ``actual`` or ``expected`` is a meta tensor, only the attribute checks will be performed.
  1094. If ``actual`` and ``expected`` are sparse (either having COO, CSR, CSC, BSR, or BSC layout), their strided members are
  1095. checked individually. Indices, namely ``indices`` for COO, ``crow_indices`` and ``col_indices`` for CSR and BSR,
  1096. or ``ccol_indices`` and ``row_indices`` for CSC and BSC layouts, respectively,
  1097. are always checked for equality whereas the values are checked for closeness according to the definition above.
  1098. If ``actual`` and ``expected`` are quantized, they are considered close if they have the same
  1099. :meth:`~torch.Tensor.qscheme` and the result of :meth:`~torch.Tensor.dequantize` is close according to the
  1100. definition above.
  1101. ``actual`` and ``expected`` can be :class:`~torch.Tensor`'s or any tensor-or-scalar-likes from which
  1102. :class:`torch.Tensor`'s can be constructed with :func:`torch.as_tensor`. Except for Python scalars the input types
  1103. have to be directly related. In addition, ``actual`` and ``expected`` can be :class:`~collections.abc.Sequence`'s
  1104. or :class:`~collections.abc.Mapping`'s in which case they are considered close if their structure matches and all
  1105. their elements are considered close according to the above definition.
  1106. .. note::
  1107. Python scalars are an exception to the type relation requirement, because their :func:`type`, i.e.
  1108. :class:`int`, :class:`float`, and :class:`complex`, is equivalent to the ``dtype`` of a tensor-like. Thus,
  1109. Python scalars of different types can be checked, but require ``check_dtype=False``.
  1110. Args:
  1111. actual (Any): Actual input.
  1112. expected (Any): Expected input.
  1113. allow_subclasses (bool): If ``True`` (default) and except for Python scalars, inputs of directly related types
  1114. are allowed. Otherwise type equality is required.
  1115. rtol (Optional[float]): Relative tolerance. If specified ``atol`` must also be specified. If omitted, default
  1116. values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
  1117. atol (Optional[float]): Absolute tolerance. If specified ``rtol`` must also be specified. If omitted, default
  1118. values based on the :attr:`~torch.Tensor.dtype` are selected with the below table.
  1119. equal_nan (Union[bool, str]): If ``True``, two ``NaN`` values will be considered equal.
  1120. check_device (bool): If ``True`` (default), asserts that corresponding tensors are on the same
  1121. :attr:`~torch.Tensor.device`. If this check is disabled, tensors on different
  1122. :attr:`~torch.Tensor.device`'s are moved to the CPU before being compared.
  1123. check_dtype (bool): If ``True`` (default), asserts that corresponding tensors have the same ``dtype``. If this
  1124. check is disabled, tensors with different ``dtype``'s are promoted to a common ``dtype`` (according to
  1125. :func:`torch.promote_types`) before being compared.
  1126. check_layout (bool): If ``True`` (default), asserts that corresponding tensors have the same ``layout``. If this
  1127. check is disabled, tensors with different ``layout``'s are converted to strided tensors before being
  1128. compared.
  1129. check_stride (bool): If ``True`` and corresponding tensors are strided, asserts that they have the same stride.
  1130. msg (Optional[Union[str, Callable[[str], str]]]): Optional error message to use in case a failure occurs during
  1131. the comparison. Can also passed as callable in which case it will be called with the generated message and
  1132. should return the new message.
  1133. Raises:
  1134. ValueError: If no :class:`torch.Tensor` can be constructed from an input.
  1135. ValueError: If only ``rtol`` or ``atol`` is specified.
  1136. AssertionError: If corresponding inputs are not Python scalars and are not directly related.
  1137. AssertionError: If ``allow_subclasses`` is ``False``, but corresponding inputs are not Python scalars and have
  1138. different types.
  1139. AssertionError: If the inputs are :class:`~collections.abc.Sequence`'s, but their length does not match.
  1140. AssertionError: If the inputs are :class:`~collections.abc.Mapping`'s, but their set of keys do not match.
  1141. AssertionError: If corresponding tensors do not have the same :attr:`~torch.Tensor.shape`.
  1142. AssertionError: If ``check_layout`` is ``True``, but corresponding tensors do not have the same
  1143. :attr:`~torch.Tensor.layout`.
  1144. AssertionError: If only one of corresponding tensors is quantized.
  1145. AssertionError: If corresponding tensors are quantized, but have different :meth:`~torch.Tensor.qscheme`'s.
  1146. AssertionError: If ``check_device`` is ``True``, but corresponding tensors are not on the same
  1147. :attr:`~torch.Tensor.device`.
  1148. AssertionError: If ``check_dtype`` is ``True``, but corresponding tensors do not have the same ``dtype``.
  1149. AssertionError: If ``check_stride`` is ``True``, but corresponding strided tensors do not have the same stride.
  1150. AssertionError: If the values of corresponding tensors are not close according to the definition above.
  1151. The following table displays the default ``rtol`` and ``atol`` for different ``dtype``'s. In case of mismatching
  1152. ``dtype``'s, the maximum of both tolerances is used.
  1153. +---------------------------+------------+----------+
  1154. | ``dtype`` | ``rtol`` | ``atol`` |
  1155. +===========================+============+==========+
  1156. | :attr:`~torch.float16` | ``1e-3`` | ``1e-5`` |
  1157. +---------------------------+------------+----------+
  1158. | :attr:`~torch.bfloat16` | ``1.6e-2`` | ``1e-5`` |
  1159. +---------------------------+------------+----------+
  1160. | :attr:`~torch.float32` | ``1.3e-6`` | ``1e-5`` |
  1161. +---------------------------+------------+----------+
  1162. | :attr:`~torch.float64` | ``1e-7`` | ``1e-7`` |
  1163. +---------------------------+------------+----------+
  1164. | :attr:`~torch.complex32` | ``1e-3`` | ``1e-5`` |
  1165. +---------------------------+------------+----------+
  1166. | :attr:`~torch.complex64` | ``1.3e-6`` | ``1e-5`` |
  1167. +---------------------------+------------+----------+
  1168. | :attr:`~torch.complex128` | ``1e-7`` | ``1e-7`` |
  1169. +---------------------------+------------+----------+
  1170. | :attr:`~torch.quint8` | ``1.3e-6`` | ``1e-5`` |
  1171. +---------------------------+------------+----------+
  1172. | :attr:`~torch.quint2x4` | ``1.3e-6`` | ``1e-5`` |
  1173. +---------------------------+------------+----------+
  1174. | :attr:`~torch.quint4x2` | ``1.3e-6`` | ``1e-5`` |
  1175. +---------------------------+------------+----------+
  1176. | :attr:`~torch.qint8` | ``1.3e-6`` | ``1e-5`` |
  1177. +---------------------------+------------+----------+
  1178. | :attr:`~torch.qint32` | ``1.3e-6`` | ``1e-5`` |
  1179. +---------------------------+------------+----------+
  1180. | other | ``0.0`` | ``0.0`` |
  1181. +---------------------------+------------+----------+
  1182. .. note::
  1183. :func:`~torch.testing.assert_close` is highly configurable with strict default settings. Users are encouraged
  1184. to :func:`~functools.partial` it to fit their use case. For example, if an equality check is needed, one might
  1185. define an ``assert_equal`` that uses zero tolerances for every ``dtype`` by default:
  1186. >>> import functools
  1187. >>> assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
  1188. >>> assert_equal(1e-9, 1e-10)
  1189. Traceback (most recent call last):
  1190. ...
  1191. AssertionError: Scalars are not equal!
  1192. <BLANKLINE>
  1193. Absolute difference: 9.000000000000001e-10
  1194. Relative difference: 9.0
  1195. Examples:
  1196. >>> # tensor to tensor comparison
  1197. >>> expected = torch.tensor([1e0, 1e-1, 1e-2])
  1198. >>> actual = torch.acos(torch.cos(expected))
  1199. >>> torch.testing.assert_close(actual, expected)
  1200. >>> # scalar to scalar comparison
  1201. >>> import math
  1202. >>> expected = math.sqrt(2.0)
  1203. >>> actual = 2.0 / math.sqrt(2.0)
  1204. >>> torch.testing.assert_close(actual, expected)
  1205. >>> # numpy array to numpy array comparison
  1206. >>> import numpy as np
  1207. >>> expected = np.array([1e0, 1e-1, 1e-2])
  1208. >>> actual = np.arccos(np.cos(expected))
  1209. >>> torch.testing.assert_close(actual, expected)
  1210. >>> # sequence to sequence comparison
  1211. >>> import numpy as np
  1212. >>> # The types of the sequences do not have to match. They only have to have the same
  1213. >>> # length and their elements have to match.
  1214. >>> expected = [torch.tensor([1.0]), 2.0, np.array(3.0)]
  1215. >>> actual = tuple(expected)
  1216. >>> torch.testing.assert_close(actual, expected)
  1217. >>> # mapping to mapping comparison
  1218. >>> from collections import OrderedDict
  1219. >>> import numpy as np
  1220. >>> foo = torch.tensor(1.0)
  1221. >>> bar = 2.0
  1222. >>> baz = np.array(3.0)
  1223. >>> # The types and a possible ordering of mappings do not have to match. They only
  1224. >>> # have to have the same set of keys and their elements have to match.
  1225. >>> expected = OrderedDict([("foo", foo), ("bar", bar), ("baz", baz)])
  1226. >>> actual = {"baz": baz, "bar": bar, "foo": foo}
  1227. >>> torch.testing.assert_close(actual, expected)
  1228. >>> expected = torch.tensor([1.0, 2.0, 3.0])
  1229. >>> actual = expected.clone()
  1230. >>> # By default, directly related instances can be compared
  1231. >>> torch.testing.assert_close(torch.nn.Parameter(actual), expected)
  1232. >>> # This check can be made more strict with allow_subclasses=False
  1233. >>> torch.testing.assert_close(
  1234. ... torch.nn.Parameter(actual), expected, allow_subclasses=False
  1235. ... )
  1236. Traceback (most recent call last):
  1237. ...
  1238. TypeError: No comparison pair was able to handle inputs of type
  1239. <class 'torch.nn.parameter.Parameter'> and <class 'torch.Tensor'>.
  1240. >>> # If the inputs are not directly related, they are never considered close
  1241. >>> torch.testing.assert_close(actual.numpy(), expected)
  1242. Traceback (most recent call last):
  1243. ...
  1244. TypeError: No comparison pair was able to handle inputs of type <class 'numpy.ndarray'>
  1245. and <class 'torch.Tensor'>.
  1246. >>> # Exceptions to these rules are Python scalars. They can be checked regardless of
  1247. >>> # their type if check_dtype=False.
  1248. >>> torch.testing.assert_close(1.0, 1, check_dtype=False)
  1249. >>> # NaN != NaN by default.
  1250. >>> expected = torch.tensor(float("Nan"))
  1251. >>> actual = expected.clone()
  1252. >>> torch.testing.assert_close(actual, expected)
  1253. Traceback (most recent call last):
  1254. ...
  1255. AssertionError: Scalars are not close!
  1256. <BLANKLINE>
  1257. Absolute difference: nan (up to 1e-05 allowed)
  1258. Relative difference: nan (up to 1.3e-06 allowed)
  1259. >>> torch.testing.assert_close(actual, expected, equal_nan=True)
  1260. >>> expected = torch.tensor([1.0, 2.0, 3.0])
  1261. >>> actual = torch.tensor([1.0, 4.0, 5.0])
  1262. >>> # The default error message can be overwritten.
  1263. >>> torch.testing.assert_close(actual, expected, msg="Argh, the tensors are not close!")
  1264. Traceback (most recent call last):
  1265. ...
  1266. AssertionError: Argh, the tensors are not close!
  1267. >>> # If msg is a callable, it can be used to augment the generated message with
  1268. >>> # extra information
  1269. >>> torch.testing.assert_close(
  1270. ... actual, expected, msg=lambda msg: f"Header\n\n{msg}\n\nFooter"
  1271. ... )
  1272. Traceback (most recent call last):
  1273. ...
  1274. AssertionError: Header
  1275. <BLANKLINE>
  1276. Tensor-likes are not close!
  1277. <BLANKLINE>
  1278. Mismatched elements: 2 / 3 (66.7%)
  1279. Greatest absolute difference: 2.0 at index (1,) (up to 1e-05 allowed)
  1280. Greatest relative difference: 1.0 at index (1,) (up to 1.3e-06 allowed)
  1281. <BLANKLINE>
  1282. Footer
  1283. """
  1284. # Hide this function from `pytest`'s traceback
  1285. __tracebackhide__ = True
  1286. error_metas = not_close_error_metas(
  1287. actual,
  1288. expected,
  1289. pair_types=(
  1290. NonePair,
  1291. BooleanPair,
  1292. NumberPair,
  1293. TensorLikePair,
  1294. ),
  1295. allow_subclasses=allow_subclasses,
  1296. rtol=rtol,
  1297. atol=atol,
  1298. equal_nan=equal_nan,
  1299. check_device=check_device,
  1300. check_dtype=check_dtype,
  1301. check_layout=check_layout,
  1302. check_stride=check_stride,
  1303. msg=msg,
  1304. )
  1305. if error_metas:
  1306. # TODO: compose all metas into one AssertionError
  1307. raise error_metas[0].to_error(msg)
  1308. def assert_allclose(
  1309. actual: Any,
  1310. expected: Any,
  1311. rtol: Optional[float] = None,
  1312. atol: Optional[float] = None,
  1313. equal_nan: bool = True,
  1314. msg: str = "",
  1315. ) -> None:
  1316. """
  1317. .. warning::
  1318. :func:`torch.testing.assert_allclose` is deprecated since ``1.12`` and will be removed in a future release.
  1319. Please use :func:`torch.testing.assert_close` instead. You can find detailed upgrade instructions
  1320. `here <https://github.com/pytorch/pytorch/issues/61844>`_.
  1321. """
  1322. warnings.warn(
  1323. "`torch.testing.assert_allclose()` is deprecated since 1.12 and will be removed in a future release. "
  1324. "Please use `torch.testing.assert_close()` instead. "
  1325. "You can find detailed upgrade instructions in https://github.com/pytorch/pytorch/issues/61844.",
  1326. FutureWarning,
  1327. stacklevel=2,
  1328. )
  1329. if not isinstance(actual, torch.Tensor):
  1330. actual = torch.tensor(actual)
  1331. if not isinstance(expected, torch.Tensor):
  1332. expected = torch.tensor(expected, dtype=actual.dtype)
  1333. if rtol is None and atol is None:
  1334. rtol, atol = default_tolerances(
  1335. actual,
  1336. expected,
  1337. dtype_precisions={
  1338. torch.float16: (1e-3, 1e-3),
  1339. torch.float32: (1e-4, 1e-5),
  1340. torch.float64: (1e-5, 1e-8),
  1341. },
  1342. )
  1343. torch.testing.assert_close(
  1344. actual,
  1345. expected,
  1346. rtol=rtol,
  1347. atol=atol,
  1348. equal_nan=equal_nan,
  1349. check_device=True,
  1350. check_dtype=False,
  1351. check_stride=False,
  1352. msg=msg or None,
  1353. )