# composite_compliance.py

import torch
from torch import Tensor
import itertools
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from functools import partial
from torch.utils._mode_utils import no_dispatch, all_same_mode
import torch.autograd.forward_ad as fwAD
from typing import Callable
import re

def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
    elem = wrapper_tensor.elem
    metadata_wrapper_tensor = metadata_accessor(wrapper_tensor)
    metadata_elem = metadata_accessor(elem)
    if metadata_wrapper_tensor == metadata_elem:
        return
    raise RuntimeError(
        f"This operator is not Composite Compliant: the "
        f"{metadata_name} of the tensor was modified directly without "
        f"going through the PyTorch dispatcher.")

def check_metadata_consistency(wrapper_tensor, CCT):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    if not isinstance(wrapper_tensor, CCT):
        return
    things_to_check = {
        'shape': Tensor.size,
        'dtype': lambda x: x.dtype,
        'device': lambda x: x.device,
        'numel': Tensor.numel,
        'stride': Tensor.stride,
        'storage_offset': Tensor.storage_offset,
    }
    for metadata_name, metadata_accessor in things_to_check.items():
        check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor)

def is_view_fn(func):
    return func.overloadpacket.__name__ in {
        'as_strided',
        'detach',
        'diagonal',
        'expand',
        'expand_as',
        'movedim',
        'narrow',
        'permute',
        'select',
        'squeeze',
        'transpose',
        't',
        'real',
        'imag',
        'view_as_real',
        'view_as_complex',
        'unflatten',
        'unfold',
        'unsqueeze',
        'view',
        'view_as',
        'unbind',
        'split',
        'split_with_sizes',
        'vsplit',
        'hsplit',
        'tensor_split',
        'chunk',
        'swapaxes',
        'slice',
        '_reshape_alias',
        '_unsafe_view',
        '_conj',
        'alias',
    }

# manually populated from native_functions that have inplace_view: True.
# In the future we will probably be able to grab that list directly
def is_inplace_view_fn(func):
    return func.overloadpacket.__name__ in {
        'as_strided_',
        'detach_',
        'squeeze_',
        'swapaxes_',
        'swapdims_',
        't_',
        'transpose_',
        'unsqueeze_',
    }

# Introspection please save us
def is_inplace(func):
    name = func.overloadpacket.__name__
    if re.match('__i.+__', name):
        return True
    if re.match('__.+__', name):
        return False
    return name[-1] == '_'

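# A minimal, illustrative sketch of the naming heuristic above.  The `_demo_*` helpers
# in this file are hypothetical examples, not used by any test; this one assumes the
# aten overloads `torch.ops.aten.add_.Tensor` and `torch.ops.aten.add.Tensor` exist.
def _demo_is_inplace():
    assert is_inplace(torch.ops.aten.add_.Tensor)      # trailing underscore -> in-place
    assert not is_inplace(torch.ops.aten.add.Tensor)   # out-of-place variant
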
def generate_cct_and_mode(autograd_view_consistency=True):
    # This function returns a new class CompositeCompliantTensor
    # The argument controls the behaviour described below.
    # autograd_view_consistency:
    #   If True, alias result using `set_` if func returns a view
    #   (See Note [Alias Result]).
    #   Since Forward AD doesn't work with `set_`, check_forward_ad_formula
    #   disables this by passing autograd_view_consistency=False.

    class CompositeCompliantTensor(torch.Tensor):
        elem: torch.Tensor

        __slots__ = ['elem']
        __torch_function__ = torch._C._disabled_torch_function_impl

        @staticmethod
        def __new__(cls, elem, mode, *args, **kwargs):
            assert type(elem) is not cls, \
                "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"

            # The storage of CompositeCompliantTensor should never be used directly
            # by a Composite operation; if the Composite
            # operator attempts to read from the storage without dispatching then it'll
            # raise a RuntimeError due to it being a meta storage.
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad,
                strides=elem.stride(), storage_offset=elem.storage_offset())

            if elem.requires_grad:
                # CompositeCompliantTensor steals the "requires_grad"-ness.
                # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
                tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype,
                                          device=elem.device, layout=elem.layout,
                                          requires_grad=False)
                tmp.copy_(elem.detach())
                r.elem = tmp
            else:
                r.elem = elem

            assert r.stride() == r.elem.stride()

            # Propagate conjugate and negative bits to the wrapper tensor
            # Ref: https://github.com/albanD/subclass_zoo/issues/24
            # Ref: https://github.com/albanD/subclass_zoo/issues/21
            torch._C._set_conj(r, r.elem.is_conj())
            torch._C._set_neg(r, r.elem.is_neg())

            r.mode = mode
            return r

        def __repr__(self):
            return f"CompositeCompliantTensor({self.elem})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            all_args = tree_flatten(args)[0] + tree_flatten(kwargs)[0]
            modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
            if not all_same_mode(modes):
                raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
            with modes[0]:
                return func(*args, **kwargs)

    class CompositeCompliantTensorMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            def unwrap(e):
                return e.elem if isinstance(e, CompositeCompliantTensor) else e

            def wrap(e):
                return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e

            if func == torch.ops.aten._local_scalar_dense.default:
                raise RuntimeError(
                    ".item() is not allowed to be called inside of composite "
                    "functions in the PyTorch library because not all backends "
                    "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support it.")

            if func.overloadpacket.__name__ in ('set_', 'resize_'):
                raise RuntimeError(
                    f"{func.__name__} is not allowed to be called inside of "
                    f"Composite operators.")

            if is_inplace(func):
                # NB: We are making an assumption that if the function is in-place,
                # then the first argument is being written to. Introspection please save us!
                mutated_argument = args[0]
                if not isinstance(mutated_argument, CompositeCompliantTensor) and \
                        any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
                    raise RuntimeError(
                        'Not composite compliant: performing in-place operation '
                        f'{func.__name__} where the Tensor being written to is a '
                        'regular Tensor but the other tensors are Tensor Subclasses. '
                        'Please try to avoid this in-place operation.')

            unwrapped_args = tree_map(unwrap, args)
            unwrapped_kwargs = tree_map(unwrap, kwargs)
            unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
            rs = tree_map(wrap, unwrapped_rs)

            if is_view_fn(func) and autograd_view_consistency:
                # Note [Alias Result]
                # Autograd asserts that for B = A.view_fn(...), B and A's storages
                # are the same. Here we try to make B alias A to avoid those asserts.
                # See https://github.com/pytorch/pytorch/issues/65339 for more information
                # about the issue.
                with no_dispatch():
                    # Idea: this is a weird way of getting a storage that aliases the input.
                    # This is a workaround for #65339.
                    # 1. under no_dispatch, all of the wrapper tensors look like regular
                    #    tensors with special storage (the storage is nullptr and
                    #    advertises CPU/CUDA device).
                    # 2. we run func, which ends up running the view operation
                    # 3. All view operations reuse the input's storage and return
                    #    result Tensor(s) with new sizes/strides/offset that alias
                    #    the input.
                    # 4. we set the storage (and sizes/strides/offset) of the wrapper
                    #    tensor results to be that of the tensors that alias the input
                    result = func(*args, **kwargs)
                    if isinstance(result, (tuple, list)):
                        for a, b in zip(rs, result):
                            a.set_(b)
                    else:
                        rs.set_(result)

            # Some operations are allowed to in-place modify the metadata of the
            # inputs. The only ones are the "inplace view functions"; when we
            # run into these, we manually modify the metadata of the input.
            with no_dispatch():
                if is_inplace_view_fn(func):
                    func(*args, **kwargs)

            # For each CompositeCompliantTensor t, we check that t and t.elem
            # have consistent metadata. If they don't have consistent metadata,
            # that means the operator did something fishy.
            check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
            tree_map(check, args)
            tree_map(check, kwargs)
            tree_map(check, rs)
            return rs

    return CompositeCompliantTensor, CompositeCompliantTensorMode()

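# A minimal, illustrative usage sketch (hypothetical, not used by any test): wrap a
# regular Tensor in the generated subclass and run ops on it.  A Composite-compliant
# op goes through the dispatcher, so the call succeeds and returns wrapped results.
def _demo_generate_cct_and_mode():
    CCT, cct_mode = generate_cct_and_mode()
    x = CCT(torch.randn(3, 3), cct_mode)
    y = x.sin().sum()                        # compliant ops work transparently
    assert isinstance(y, CCT)                # results come back wrapped
    assert isinstance(y.elem, torch.Tensor)  # the real data lives in .elem
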
def is_tensorlist(lst):
    if not isinstance(lst, list) and not isinstance(lst, tuple):
        return False
    if len(lst) == 0:
        return False
    all_tensors = all(isinstance(elt, torch.Tensor) for elt in lst)
    if all_tensors:
        return True
    exists_one_tensor = any(isinstance(elt, torch.Tensor) for elt in lst)
    if exists_one_tensor:
        raise RuntimeError('This test assumes that PyTorch APIs cannot take '
                           'mixed lists of Tensor and other things')
    return False

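# A small, illustrative sketch (hypothetical, not used by any test) of what
# is_tensorlist accepts and rejects.
def _demo_is_tensorlist():
    assert is_tensorlist([torch.randn(2), torch.randn(3)])  # homogeneous list of Tensors
    assert not is_tensorlist([])                            # empty sequences are rejected
    assert not is_tensorlist(torch.randn(2))                # a bare Tensor is not a tensorlist
    assert not is_tensorlist([1, 2, 3])                     # a list of non-Tensors
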
def maybe_map(fn, should_map, arg):
    return fn(arg) if should_map else arg


def wrap(arg, CCT, cct_mode):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    if isinstance(arg, torch.Tensor):
        return CCT(arg, cct_mode)
    if is_tensorlist(arg):
        return [CCT(a, cct_mode) for a in arg]
    raise RuntimeError("wrap assumes that the input can be wrapped")

# Given a list of flat arguments, some of which may be Tensors, return all
# possible ways some of the arguments could be CompositeCompliantTensors (CCT).
# For example, given Tensors A, B and flat_args = [A, 1, B],
# we would return the following 4 options:
# [CCT(A), 1, CCT(B)]
# [CCT(A), 1, B]
# [A, 1, CCT(B)]
# [A, 1, B]
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
# don't accept that many input Tensors.
def generate_subclass_choices(flat_args, CCT, cct_mode):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    is_tensor_likes = [isinstance(arg, torch.Tensor) or is_tensorlist(arg) for arg in flat_args]
    subclass_options = [[False, True] if is_tensor_like else [False] for is_tensor_like in is_tensor_likes]

    for which_args_are_wrapped in itertools.product(*subclass_options):
        result = [maybe_map(partial(wrap, CCT=CCT, cct_mode=cct_mode), should_wrap_arg, arg)
                  for should_wrap_arg, arg in zip(which_args_are_wrapped, flat_args)]
        yield result, which_args_are_wrapped

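# A small, illustrative sketch (hypothetical, not used by any test) confirming the
# enumeration described above: each Tensor-like argument is independently either left
# alone or wrapped, so [A, 1, B] yields 2 * 1 * 2 = 4 choices.
def _demo_generate_subclass_choices():
    CCT, cct_mode = generate_cct_and_mode()
    A, B = torch.randn(2), torch.randn(3)
    choices = list(generate_subclass_choices([A, 1, B], CCT, cct_mode))
    assert len(choices) == 4
    for result, which_args_are_wrapped in choices:
        assert which_args_are_wrapped[1] is False  # non-Tensor args are never wrapped
        assert result[1] == 1
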
# For an operation f(*args, **kwargs), each Tensor argument may either be
# a regular Tensor or a Tensor Subclass. This iterator iterates through
# all of those options.
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
    # CCT: CompositeCompliantTensor class which is generated using generate_cct_and_mode
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_args_kwargs = list(args) + list(flat_kwargs)
    for choice, debug_metadata in generate_subclass_choices(flat_args_kwargs, CCT, cct_mode):
        new_args = choice[:len(args)]
        new_kwargs = tree_unflatten(choice[len(args):], spec)
        which_args_are_wrapped = debug_metadata[:len(args)]
        which_kwargs_are_wrapped = tree_unflatten(debug_metadata[len(args):], spec)
        yield new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped

def raise_composite_compliance_error(err, additional_info=''):
    raise RuntimeError(
        "Composite compliance check failed with "
        "the above error.\n"
        f"{additional_info}"
        "If you are adding an OpInfo of an "
        "existing operator, please feel free to skip this test "
        "and file an issue, because the problem was pre-existing. "
        "Otherwise, if you added a new operator, please read "
        "through the Composite Compliance section in "
        "aten/src/ATen/native/README.md for how to resolve this. "
    ) from err

# This test checks ALL possible permutations of calling `op` with arguments
# that are individually either a regular Tensor or a Tensor subclass.
#
# The general strategy is to wrap some Tensor args and kwargs in
# CompositeCompliantTensor wrappers and call the operation.
# If some composite operation does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_all_permutations(op, args, kwargs, assert_equal_fn):
    CCT, cct_mode = generate_cct_and_mode()
    expected = op(*args, **kwargs)
    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

        try:
            actual = op(*new_args, **new_kwargs)
        # NOTE: [What errors are Composite Compliance trying to catch?]
        #
        # There are two things we want to catch:
        # - errors that would be raised within the torch_dispatch impl
        # - data_ptr accesses
        # The first is easy to filter for (we could make the error a different
        # error class), the second is always going to be a RuntimeError due to
        # how it is implemented (if you try to access the data_ptr of the
        # wrapper Tensor, it raises an internal RuntimeError).
        #
        # So the most general thing to catch here is RuntimeError. If you
        # are here and debugging why your test failed, it's plausible that
        # the operator itself is broken and that there are other tests failing.
        except RuntimeError as err:
            raise_composite_compliance_error(
                err,
                f"- wrapped_args: {which_args_are_wrapped}\n"
                f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
            )

        def unwrap(e):
            return e.elem if isinstance(e, CCT) else e

        assert_equal_fn(tree_map(unwrap, actual), expected)

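# A minimal, illustrative usage sketch (hypothetical, not used by any test): check
# torch.add over every wrapping permutation of its two Tensor inputs, using
# torch.testing.assert_close as the comparison function.
def _demo_check_all_permutations():
    args = (torch.randn(3), torch.randn(3))
    check_all_permutations(torch.add, args, {}, torch.testing.assert_close)
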
# Checks, via the use of torch dispatch mode, for certain anti-patterns that
# are not composite compliant.
#
# In particular, the anti-pattern we are trying to prevent is a user
# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps
# here because all factory functions will create tensors that are
# CompositeCompliantTensor.
#
# The general strategy is to wrap all Tensor args and kwargs in
# CompositeCompliantTensor wrappers. If an operator that is
# Composite does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_with_mode(op, args, kwargs, assert_equal_fn):
    CCT, cct_mode = generate_cct_and_mode()

    def wrap(e):
        return CCT(e, cct_mode) if isinstance(e, torch.Tensor) else e

    expected = op(*args, **kwargs)

    args = tree_map(wrap, args)
    kwargs = tree_map(wrap, kwargs)
    try:
        with cct_mode:
            actual = op(*args, **kwargs)
    # see NOTE: [What errors are Composite Compliance trying to catch?]
    except RuntimeError as err:
        raise_composite_compliance_error(err)

    def unwrap(e):
        return e.elem if isinstance(e, CCT) else e

    assert_equal_fn(tree_map(unwrap, actual), expected)

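# A matching illustrative sketch (hypothetical, not used by any test) for the mode-based
# check: with the mode active, factory functions called inside the op also produce
# wrapped tensors, so non-compliant resize_/set_ patterns get caught.
def _demo_check_with_mode():
    args = (torch.randn(3),)
    check_with_mode(torch.nn.functional.relu, args, {}, torch.testing.assert_close)
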
def gather_leaf_tensors(args, kwargs):
    leaf_tensors = []
    args, args_spec = tree_flatten(args)
    kwargs, kwargs_spec = tree_flatten(kwargs)
    args = args + kwargs
    for arg in args:
        if not isinstance(arg, torch.Tensor):
            continue
        if arg.requires_grad:
            leaf_tensors.append(arg)
    return leaf_tensors

def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
    if gradcheck_wrapper is None:
        results = op(*args, **kwargs)
    else:
        results = gradcheck_wrapper(op, *args, **kwargs)

    if output_process_fn_grad is not None:
        results = output_process_fn_grad(results)

    flat_results, _ = tree_flatten(results)
    flat_diff_results = [r for r in flat_results if r.requires_grad]
    assert len(flat_diff_results) > 0

    grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype) for r in flat_diff_results]
    leaf_tensors = gather_leaf_tensors(args, kwargs)
    assert len(leaf_tensors) > 0
    return torch.autograd.grad(flat_diff_results, leaf_tensors,
                               grads, allow_unused=True, retain_graph=True)

# Checks if the backward formula is composite compliant by testing
# all possible permutations of {inputs, grad_outputs} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# so that we can apply check_backward_formula to things that aren't OpInfos
# while debugging.
def check_backward_formula(op: Callable, args, kwargs,
                           output_process_fn_grad=None,
                           gradcheck_wrapper=None, assert_equal_fn=None):
    CCT, cct_mode = generate_cct_and_mode()

    expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)

    for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
        new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice
        leaf_tensors = gather_leaf_tensors(new_args, new_kwargs)
        assert len(leaf_tensors) > 0

        try:
            if gradcheck_wrapper is None:
                results = op(*new_args, **new_kwargs)
            else:
                results = gradcheck_wrapper(op, *new_args, **new_kwargs)
            if output_process_fn_grad is not None:
                results = output_process_fn_grad(results)
        # see NOTE: [What errors are Composite Compliance trying to catch?]
        except RuntimeError as err:
            raise_composite_compliance_error(
                err,
                f"- wrapped_args: {which_args_are_wrapped}\n"
                f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
            )

        flat_results, _ = tree_flatten(results)
        flat_diff_results = [r for r in flat_results if r.requires_grad]
        assert len(flat_diff_results) > 0

        # NB: ones, not ones_like, so we get a regular Tensor here
        grads = [torch.ones(r.shape, device=r.device, dtype=r.dtype)
                 for r in flat_diff_results]
        for flat_new_grads, which_grads_are_wrapped in generate_subclass_choices(grads, CCT, cct_mode):
            try:
                actual = torch.autograd.grad(flat_diff_results, leaf_tensors, flat_new_grads,
                                             allow_unused=True, retain_graph=True)
            # see NOTE: [What errors are Composite Compliance trying to catch?]
            except RuntimeError as err:
                raise_composite_compliance_error(
                    err,
                    f"- wrapped_args: {which_args_are_wrapped}\n"
                    f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                    f"- wrapped_grads: {which_grads_are_wrapped}\n"
                )

            def unwrap(e):
                return e.elem if isinstance(e, CCT) else e

            assert_equal_fn(tuple(map(unwrap, actual)), expected, equal_nan=True)

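# A minimal, illustrative usage sketch (hypothetical, not used by any test): check the
# backward formula of torch.mul with two differentiable inputs.
def _demo_check_backward_formula():
    a = torch.randn(3, requires_grad=True)
    b = torch.randn(3, requires_grad=True)
    check_backward_formula(torch.mul, (a, b), {},
                           assert_equal_fn=torch.testing.assert_close)
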
# Checks if the forward AD formula is composite compliant by testing
# all possible permutations of {primals, tangents} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# so that we can apply check_forward_ad_formula to things that aren't OpInfos
# while debugging.
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
    CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)

    def maybe_tangent(t):
        assert type(t) is not CCT
        # Generate a `tangent` tensor
        # if the given object is a Tensor and requires_grad is set.
        if isinstance(t, torch.Tensor) and t.requires_grad:
            return torch.randn_like(t)
        elif is_tensorlist(t):
            return [torch.randn_like(e) if e.requires_grad else None for e in t]
        return None

    tangent_args = tuple(maybe_tangent(arg) for arg in args)
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
    tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

    with fwAD.dual_level():
        def maybe_make_dual(dual):
            # Returns a dual tensor if the primal is a tensor/tensor subclass
            # with requires_grad set.
            primal, tangent = dual
            if isinstance(primal, torch.Tensor) and primal.requires_grad:
                return fwAD.make_dual(primal.detach(), tangent)
            elif is_tensorlist(primal):
                return tuple(fwAD.make_dual(pri.detach(), tang) if tang is not None else pri
                             for pri, tang in zip(primal, tangent))
            return primal

        def compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs):
            op_args = tuple(map(maybe_make_dual, zip(args, tangent_args)))
            op_kwargs = {k: maybe_make_dual((v, tangent_kwargs[k])) for k, v in kwargs.items()}

            if gradcheck_wrapper is None:
                return op(*op_args, **op_kwargs)
            return gradcheck_wrapper(op, *op_args, **op_kwargs)

        expected = compute_expected_grad(args, tangent_args, kwargs, tangent_kwargs)
        expected = tree_map(fwAD.unpack_dual, expected)
        expected_primals = tree_map(lambda x: x.primal, expected)
        expected_tangents = tree_map(lambda x: x.tangent, expected)

        # Permutations of args and kwargs in CCT.
        for choice in generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
            new_args, new_kwargs, which_args_are_wrapped, which_kwargs_are_wrapped = choice

            # Permutations of tangent args and tangent kwargs in CCT.
            for tang_choice in generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
                new_tang_args, new_tang_kwargs, \
                    which_tang_args_are_wrapped, which_tang_kwargs_are_wrapped = tang_choice

                op_args = tuple(map(maybe_make_dual, zip(new_args, new_tang_args)))
                op_kwargs = {k: maybe_make_dual((v, new_tang_kwargs[k])) for k, v in new_kwargs.items()}

                try:
                    if gradcheck_wrapper is None:
                        actual = op(*op_args, **op_kwargs)
                    else:
                        actual = gradcheck_wrapper(op, *op_args, **op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {which_args_are_wrapped}\n"
                        f"- wrapped_kwargs: {which_kwargs_are_wrapped}\n"
                        f"- wrapped_tangent_args: {which_tang_args_are_wrapped}\n"
                        f"- wrapped_tangent_kwargs: {which_tang_kwargs_are_wrapped}\n"
                    )

                def unwrap(e):
                    return e.elem if isinstance(e, CCT) else e

                actual = tree_map(fwAD.unpack_dual, actual)
                actual_primals = tree_map(lambda x: unwrap(x.primal), actual)
                actual_tangents = tree_map(lambda x: unwrap(x.tangent), actual)
                assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
                assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)

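# A minimal, illustrative usage sketch (hypothetical, not used by any test): check the
# forward-AD formula of torch.sin for a single differentiable input.
def _demo_check_forward_ad_formula():
    x = torch.randn(3, requires_grad=True)
    check_forward_ad_formula(torch.sin, (x,), {},
                             assert_equal_fn=torch.testing.assert_close)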