123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577 |
- import torch
- from torch import Tensor
- import itertools
- from torch.utils._python_dispatch import TorchDispatchMode
- from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
- from functools import partial
- from torch.utils._mode_utils import no_dispatch, all_same_mode
- import torch.autograd.forward_ad as fwAD
- from typing import Callable
- import re
def check_attr_consistency(wrapper_tensor, metadata_name, metadata_accessor):
    """Check that one piece of metadata agrees between a wrapper and its ``elem``.

    Args:
        wrapper_tensor: object exposing the wrapped tensor as ``.elem``.
        metadata_name: human-readable name of the metadata, used in the error.
        metadata_accessor: callable applied to both the wrapper and its
            ``elem`` to read the metadata being compared.

    Raises:
        RuntimeError: if the two values do not compare equal.
    """
    inner = wrapper_tensor.elem
    outer_value = metadata_accessor(wrapper_tensor)
    inner_value = metadata_accessor(inner)
    if not (outer_value == inner_value):
        # The wrapper and the wrapped tensor disagree, so something changed
        # the metadata behind the dispatcher's back.
        raise RuntimeError(
            f"This operator is not Composite Compliant: the "
            f"{metadata_name} of the tensor was modified directly without "
            f"going through the PyTorch dispatcher.")
def check_metadata_consistency(wrapper_tensor, CCT):
    """Check every tracked metadata field of a CCT against its ``elem``.

    Args:
        wrapper_tensor: any object; anything that is not an instance of
            ``CCT`` is ignored.
        CCT: the CompositeCompliantTensor class produced by
            ``generate_cct_and_mode``.

    Raises:
        RuntimeError: propagated from ``check_attr_consistency`` on the first
            field that differs between the wrapper and ``wrapper_tensor.elem``.
    """
    if isinstance(wrapper_tensor, CCT):
        # (name used in the error message, accessor applied to both tensors)
        tracked_fields = (
            ('shape', Tensor.size),
            ('dtype', lambda t: t.dtype),
            ('device', lambda t: t.device),
            ('numel', Tensor.numel),
            ('stride', Tensor.stride),
            ('storage_offset', Tensor.storage_offset),
        )
        for field_name, accessor in tracked_fields:
            check_attr_consistency(wrapper_tensor, field_name, accessor)
def is_view_fn(func):
    """Return True if ``func``'s overload packet name is a known view op.

    ``func`` must expose ``overloadpacket.__name__``. The list of names below
    is maintained by hand.
    """
    view_op_names = frozenset((
        'as_strided', 'detach', 'diagonal', 'expand', 'expand_as',
        'movedim', 'narrow', 'permute', 'select', 'squeeze',
        'transpose', 't', 'real', 'imag',
        'view_as_real', 'view_as_complex',
        'unflatten', 'unfold', 'unsqueeze',
        'view', 'view_as',
        'unbind', 'split', 'split_with_sizes',
        'vsplit', 'hsplit', 'tensor_split', 'chunk',
        'swapaxes', 'slice',
        '_reshape_alias', '_unsafe_view', '_conj', 'alias',
    ))
    op_name = func.overloadpacket.__name__
    return op_name in view_op_names
# Hand-maintained from the native_functions entries that set
# `inplace_view: True`. In the future we will probably be able to grab that
# list directly.
def is_inplace_view_fn(func):
    """Return True if ``func``'s overload packet name is an in-place view op.

    ``func`` must expose ``overloadpacket.__name__``.
    """
    inplace_view_op_names = frozenset((
        'as_strided_', 'detach_', 'squeeze_', 'swapaxes_',
        'swapdims_', 't_', 'transpose_', 'unsqueeze_',
    ))
    op_name = func.overloadpacket.__name__
    return op_name in inplace_view_op_names
# Introspection please save us
def is_inplace(func):
    """Guess from its name alone whether ``func`` is an in-place operator.

    Names matching the dunder pattern ``__.+__`` count as in-place only when
    they also match ``__i.+__``; every other name counts as in-place when its
    last character is an underscore. ``func`` must expose
    ``overloadpacket.__name__``.
    """
    op_name = func.overloadpacket.__name__
    if re.match('__.+__', op_name):
        # Anything matching '__i.+__' also matches '__.+__', so deciding the
        # in-place case inside this branch loses no names.
        return re.match('__i.+__', op_name) is not None
    return op_name[-1] == '_'
def generate_cct_and_mode(autograd_view_consistency=True):
    """Create a fresh CompositeCompliantTensor class and a matching dispatch mode.

    Args:
        autograd_view_consistency: If True, the result of a view function is
            made to alias its input using ``set_`` (see Note [Alias Result]
            below). Forward AD doesn't work with ``set_``, so the aliasing is
            skipped when this is False.

    Returns:
        A tuple ``(CompositeCompliantTensor, mode)`` where ``mode`` is a new
        ``CompositeCompliantTensorMode`` instance.
    """

    class CompositeCompliantTensor(torch.Tensor):
        # Wrapper subclass: the real data lives in `elem`, and the dispatch
        # mode that created the wrapper is stored on the instance as `mode`.
        elem: torch.Tensor

        __slots__ = ['elem']

        __torch_function__ = torch._C._disabled_torch_function_impl

        @staticmethod
        def __new__(cls, elem, mode, *args, **kwargs):
            assert type(elem) is not cls, \
                "Wrapping a CompositeCompliantTensor in a CompositeCompliantTensor is not supported"

            # The storage of CompositeCompliantTensor should never be used directly
            # by a Composite operation; if the Composite
            # operator attempts to read from the storage without dispatching then it'll
            # raise a RuntimeError due to it being a meta storage.
            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                cls, elem.size(),
                dtype=elem.dtype, layout=elem.layout,
                device=elem.device, requires_grad=elem.requires_grad,
                strides=elem.stride(), storage_offset=elem.storage_offset())

            if elem.requires_grad:
                # CompositeCompliantTensor steals the "requires_grad"-ness.
                # Why a new copy of `elem`? Because sometimes OpInfo shares inputs between tests...
                tmp = torch.empty_strided(elem.shape, elem.stride(), dtype=elem.dtype,
                                          device=elem.device, layout=elem.layout,
                                          requires_grad=False)
                tmp.copy_(elem.detach())
                r.elem = tmp
            else:
                r.elem = elem

            assert r.stride() == r.elem.stride()

            # Propagate conjugate bits to the wrapper tensor
            # Ref: https://github.com/albanD/subclass_zoo/issues/24
            # Ref: https://github.com/albanD/subclass_zoo/issues/21
            torch._C._set_conj(r, r.elem.is_conj())
            torch._C._set_neg(r, r.elem.is_neg())

            r.mode = mode
            return r

        def __repr__(self):
            return f"CompositeCompliantTensor({self.elem})"

        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # `kwargs` defaults to None but is splatted below; normalise it so
            # a call without keyword arguments does not fail on `**None`.
            if kwargs is None:
                kwargs = {}

            all_args = tree_flatten(args)[0] + tree_flatten(kwargs)[0]
            modes = tuple(e.mode for e in all_args if isinstance(e, CompositeCompliantTensor))
            if not all_same_mode(modes):
                raise RuntimeError("Multiple CompositeCompliantTensorModes NYI")
            # Re-run the call with the wrappers' mode active so that
            # CompositeCompliantTensorMode.__torch_dispatch__ handles it.
            with modes[0]:
                return func(*args, **kwargs)

    class CompositeCompliantTensorMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # Same normalisation as above: `kwargs` is splatted several times.
            if kwargs is None:
                kwargs = {}

            def unwrap(e):
                return e.elem if isinstance(e, CompositeCompliantTensor) else e

            def wrap(e):
                return CompositeCompliantTensor(e, self) if isinstance(e, torch.Tensor) else e

            if func == torch.ops.aten._local_scalar_dense.default:
                raise RuntimeError(
                    ".item() is not allowed to be called inside of composite "
                    "functions in the PyTorch library because not all backends "
                    "and/or Tensor subclasses (e.g. vmap, ProxyTensor) support them.")

            if func.overloadpacket.__name__ in ('set_', 'resize_'):
                raise RuntimeError(
                    f"{func.__name__} is not allowed to be called inside of "
                    f"Composite operators.")

            if is_inplace(func):
                # NB: We are making an assumption that if the function is in-place,
                # then the first argument is being written to. Introspection please save us!
                mutated_argument = args[0]
                if not isinstance(mutated_argument, CompositeCompliantTensor) and \
                        any(isinstance(a, CompositeCompliantTensor) for a in args[1:]):
                    raise RuntimeError(
                        'Not composite compliant: performing in-place operation '
                        f'{func.__name__} where the Tensor being written to is '
                        'regular Tensor but the other tensors are Tensor Subclasses. '
                        'Please try to avoid this in-place operation.')

            # Run the real computation on the wrapped tensors, then re-wrap
            # every Tensor in the result.
            unwrapped_args = tree_map(unwrap, args)
            unwrapped_kwargs = tree_map(unwrap, kwargs)
            unwrapped_rs = func(*unwrapped_args, **unwrapped_kwargs)
            rs = tree_map(wrap, unwrapped_rs)

            if is_view_fn(func) and autograd_view_consistency:
                # Note [Alias Result]
                # Autograd asserts that for B = A.view_fn(...), B and A's storages
                # are the same. Here we try to make B alias A to avoid those asserts.
                # See https://github.com/pytorch/pytorch/issues/65339 for more information
                # about the issue.
                with no_dispatch():
                    # Idea: this is a weird way of getting a storage that aliases the input.
                    # This is a workaround for #65339.
                    # 1. under no_dispatch, all of the wrapper tensors look like regular
                    #    tensors with special storage (the storage is nullptr and
                    #    advertises CPU/CUDA device).
                    # 2. we run func, which ends up running the view operation
                    # 3. All view operations reuse the input's storage and return
                    #    result Tensor(s) with new sizes/strides/offset that alias
                    #    the input.
                    # 4. we set the storage (and sizes/strides/offset) of the wrapper
                    #    tensor results to be that of the tensors that alias the input
                    result = func(*args, **kwargs)
                    if isinstance(result, (tuple, list)):
                        for a, b in zip(rs, result):
                            a.set_(b)
                    else:
                        rs.set_(result)

            # Some operations are allowed to in-place modify the metadata of the
            # inputs. The only ones are the "inplace view functions"; when we
            # run into these, we manually modify the metadata of the input.
            with no_dispatch():
                if is_inplace_view_fn(func):
                    func(*args, **kwargs)

            # For each CompositeCompliantTensor t, we check that t and t.elem
            # have consistent metadata. If they don't have consistent metadata,
            # that means the operator did something fishy.
            check = partial(check_metadata_consistency, CCT=CompositeCompliantTensor)
            tree_map(check, args)
            tree_map(check, kwargs)
            tree_map(check, rs)
            return rs

    return CompositeCompliantTensor, CompositeCompliantTensorMode()
def is_tensorlist(lst):
    """Return True iff ``lst`` is a non-empty list or tuple holding only Tensors.

    Every other input yields False: a non-list/tuple, an empty sequence, or a
    sequence in which at least one element is not a ``torch.Tensor``.

    NOTE(review): the previous implementation had a second check intended to
    raise a RuntimeError for sequences mixing Tensors with other things, but
    it repeated the ``all(...)`` test that had just failed, so it could never
    fire. The dead branch is removed rather than turned into a live
    ``any(...)`` check because code in this file depends on mixed sequences
    returning False: check_forward_ad_formula builds tangent lists that hold
    ``None`` next to Tensors and passes them through
    generate_subclass_choices, which calls this function on each argument.
    """
    if not isinstance(lst, (list, tuple)) or len(lst) == 0:
        return False
    return all(isinstance(elt, torch.Tensor) for elt in lst)
def maybe_map(fn, should_map, arg):
    """Return ``fn(arg)`` when ``should_map`` is truthy, otherwise ``arg`` itself."""
    if should_map:
        return fn(arg)
    return arg
def wrap(arg, CCT, cct_mode):
    """Wrap a Tensor, or every Tensor of a tensor list, in ``CCT``.

    Args:
        arg: a ``torch.Tensor`` or something ``is_tensorlist`` accepts.
        CCT: the CompositeCompliantTensor class produced by
            ``generate_cct_and_mode``.
        cct_mode: the mode handed to every ``CCT`` that is constructed.

    Returns:
        ``CCT(arg, cct_mode)`` for a Tensor; for a tensor list, a new ``list``
        of wrapped elements.

    Raises:
        RuntimeError: if ``arg`` is neither a Tensor nor a tensor list.
    """
    if not isinstance(arg, torch.Tensor):
        if not is_tensorlist(arg):
            raise RuntimeError("wrap assumes that the input can be wrapped")
        return [CCT(item, cct_mode) for item in arg]
    return CCT(arg, cct_mode)
# Given a list of flat arguments, some of which may be Tensors, enumerate all
# possible ways some of the arguments could be CompositeCompliantTensors (CCT).
# For example, given Tensors A, B, C and flat_args = [A, 1, B],
# the following 4 options are produced:
# [A, 1, B]
# [A, 1, CCT(B)]
# [CCT(A), 1, B]
# [CCT(A), 1, CCT(B)]
# NB: Yes, this is exponential. No, we don't care too much because PyTorch ops
# don't accept that many input Tensors.
def generate_subclass_choices(flat_args, CCT, cct_mode):
    """Yield ``(choice, which_args_are_wrapped)`` for every wrapping combination.

    Args:
        flat_args: sequence of arguments; Tensors and tensor lists are the
            only entries that get a "wrapped" variant.
        CCT: the CompositeCompliantTensor class produced by
            ``generate_cct_and_mode``.
        cct_mode: the mode handed to every ``CCT`` that is constructed.

    Yields:
        A list with the same length as ``flat_args`` (entries wrapped where
        selected), paired with the tuple of booleans saying which entries
        were wrapped.
    """
    def to_cct(arg):
        return wrap(arg, CCT=CCT, cct_mode=cct_mode)

    # Per argument: the set of "is wrapped?" values it may take.
    wrap_options = []
    for arg in flat_args:
        if isinstance(arg, torch.Tensor) or is_tensorlist(arg):
            wrap_options.append([False, True])
        else:
            wrap_options.append([False])

    for wrapped_mask in itertools.product(*wrap_options):
        candidate = [to_cct(arg) if is_wrapped else arg
                     for is_wrapped, arg in zip(wrapped_mask, flat_args)]
        yield candidate, wrapped_mask
# For an operation f(*args, **kwargs), each Tensor argument may either be
# a regular Tensor or a Tensor Subclass. This generator walks through
# all of those options.
def generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
    """Yield every wrapped/unwrapped combination of ``args`` and ``kwargs``.

    Positional arguments are taken as-is; ``kwargs`` is flattened with
    ``tree_flatten`` and rebuilt with ``tree_unflatten`` for each choice.

    Args:
        args: positional arguments of the operation.
        kwargs: keyword arguments of the operation.
        CCT: the CompositeCompliantTensor class produced by
            ``generate_cct_and_mode``.
        cct_mode: the mode handed to every ``CCT`` that is constructed.

    Yields:
        ``(new_args, new_kwargs, which_args_are_wrapped,
        which_kwargs_are_wrapped)``; the last two mirror the structure of the
        first two and hold booleans.
    """
    kwarg_leaves, kwarg_spec = tree_flatten(kwargs)
    n_positional = len(args)
    combined = [*args, *kwarg_leaves]
    for candidate, wrapped_mask in generate_subclass_choices(combined, CCT, cct_mode):
        yield (
            candidate[:n_positional],
            tree_unflatten(candidate[n_positional:], kwarg_spec),
            wrapped_mask[:n_positional],
            tree_unflatten(wrapped_mask[n_positional:], kwarg_spec),
        )
def raise_composite_compliance_error(err, additional_info=''):
    """Raise a RuntimeError explaining a composite-compliance failure.

    Args:
        err: the original exception; it becomes the ``__cause__`` of the
            raised error.
        additional_info: extra text inserted after the first line of the
            message.

    Raises:
        RuntimeError: always.
    """
    headline = (
        "Composite compliance check failed with "
        "the above error.\n"
    )
    guidance = (
        "If you are adding an OpInfo of an "
        "existing operator, please feel free to skip this test "
        "because the problem was pre-existing and file an issue. "
        "Otherwise, if you added a new operator, please read "
        "through the Composite Compliance section in "
        "aten/src/ATen/native/README.md for how to resolve this. "
    )
    raise RuntimeError(headline + f"{additional_info}" + guidance) from err
# This test checks ALL possible permutations of calling `op` with arguments
# that are individually either a regular Tensor or a Tensor subclass.
#
# The general strategy is to wrap some Tensor args and kwargs in
# CompositeCompliantTensor wrappers and call the operation.
# If some composite operation does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_all_permutations(op, args, kwargs, assert_equal_fn):
    """Run ``op`` under every wrapped/unwrapped argument combination.

    Args:
        op: callable under test.
        args: positional arguments for ``op``.
        kwargs: keyword arguments for ``op``.
        assert_equal_fn: called as ``assert_equal_fn(actual, expected)`` with
            the unwrapped result of each combination and the result of the
            plain call.

    Raises:
        RuntimeError: via ``raise_composite_compliance_error`` when ``op``
            raises a RuntimeError for some combination.
    """
    CCT, cct_mode = generate_cct_and_mode()

    def strip_cct(e):
        if isinstance(e, CCT):
            return e.elem
        return e

    expected = op(*args, **kwargs)
    for new_args, new_kwargs, wrapped_args, wrapped_kwargs in \
            generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
        try:
            actual = op(*new_args, **new_kwargs)
        # NOTE: [What errors are Composite Compliance trying to catch?]
        #
        # There's two things we want to catch:
        # - errors that would raise within the torch_dispatch impl
        # - data_ptr accesses
        # The first is easy to filter for (we could make the error a different
        # error class), the second is always going to be a RuntimeError due to
        # how it is implemented (if you try to access the data_ptr of the
        # wrapper Tensor, it raises you some internal RuntimeError).
        #
        # So the most general thing to catch here was RuntimeError. If you
        # are here and debugging why your test failed, it's plausible that
        # the operator itself is broken and that there are other tests failing.
        except RuntimeError as err:
            raise_composite_compliance_error(
                err,
                f"- wrapped_args: {wrapped_args}\n"
                f"- wrapped_kwargs: {wrapped_kwargs}\n"
            )

        assert_equal_fn(tree_map(strip_cct, actual), expected)
# Uses a torch dispatch mode to check for certain anti-patterns that
# are not composite compliant.
#
# In particular, the anti-pattern we are trying to prevent is a user
# creating an empty tensor and then resize_-ing it. Torch Dispatch Mode helps
# here because all factory functions will create tensors that are
# CompositeCompliantTensor.
#
# The general strategy is to wrap all Tensor args and kwargs in
# CompositeCompliantTensor wrappers. If an operator that is
# Composite does any non-compliant behavior,
# CompositeCompliantTensor will raise an error.
def check_with_mode(op, args, kwargs, assert_equal_fn):
    """Run ``op`` with every Tensor input wrapped and the CCT mode active.

    Args:
        op: callable under test.
        args: positional arguments for ``op``.
        kwargs: keyword arguments for ``op``.
        assert_equal_fn: called as ``assert_equal_fn(actual, expected)`` with
            the unwrapped result and the result of the plain call.

    Raises:
        RuntimeError: via ``raise_composite_compliance_error`` when the
            wrapped call raises a RuntimeError.
    """
    CCT, cct_mode = generate_cct_and_mode()

    def to_cct(e):
        if isinstance(e, torch.Tensor):
            return CCT(e, cct_mode)
        return e

    def from_cct(e):
        if isinstance(e, CCT):
            return e.elem
        return e

    # Reference result is computed on the plain inputs before wrapping.
    expected = op(*args, **kwargs)

    wrapped_args = tree_map(to_cct, args)
    wrapped_kwargs = tree_map(to_cct, kwargs)
    try:
        with cct_mode:
            actual = op(*wrapped_args, **wrapped_kwargs)
    # see NOTE: [What errors are Composite Compliance trying to catch?]
    except RuntimeError as err:
        raise_composite_compliance_error(err)

    assert_equal_fn(tree_map(from_cct, actual), expected)
def gather_leaf_tensors(args, kwargs):
    """Collect every Tensor in ``args``/``kwargs`` that has ``requires_grad`` set.

    Both containers are flattened with ``tree_flatten``; the result lists the
    matching Tensors from ``args`` first, then those from ``kwargs``.
    """
    arg_leaves, _ = tree_flatten(args)
    kwarg_leaves, _ = tree_flatten(kwargs)
    return [
        leaf for leaf in arg_leaves + kwarg_leaves
        if isinstance(leaf, torch.Tensor) and leaf.requires_grad
    ]
def compute_expected_grads(op, args, kwargs, output_process_fn_grad=None, gradcheck_wrapper=None):
    """Compute reference gradients of ``op`` with all-ones output gradients.

    Args:
        op: callable under test.
        args: positional arguments for ``op``.
        kwargs: keyword arguments for ``op``.
        output_process_fn_grad: optional callable applied to the raw results
            before they are differentiated.
        gradcheck_wrapper: optional callable invoked as
            ``gradcheck_wrapper(op, *args, **kwargs)`` instead of calling
            ``op`` directly.

    Returns:
        The result of ``torch.autograd.grad`` of the outputs that require
        grad with respect to the inputs returned by ``gather_leaf_tensors``.
    """
    if gradcheck_wrapper is not None:
        outputs = gradcheck_wrapper(op, *args, **kwargs)
    else:
        outputs = op(*args, **kwargs)

    if output_process_fn_grad is not None:
        outputs = output_process_fn_grad(outputs)

    flat_outputs = tree_flatten(outputs)[0]
    diff_outputs = [out for out in flat_outputs if out.requires_grad]
    assert len(diff_outputs) > 0

    grad_outputs = [torch.ones(out.shape, device=out.device, dtype=out.dtype)
                    for out in diff_outputs]
    leaves = gather_leaf_tensors(args, kwargs)
    assert len(leaves) > 0
    return torch.autograd.grad(diff_outputs, leaves,
                               grad_outputs, allow_unused=True, retain_graph=True)
# Checks if the backward formula is composite compliant by testing
# all possible permutations of {inputs, grad_outputs} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# this means we can apply check_backward_formula to things that aren't OpInfos
# while debugging.
def check_backward_formula(op: Callable, args, kwargs,
                           output_process_fn_grad=None,
                           gradcheck_wrapper=None, assert_equal_fn=None):
    """Compare gradients for every wrapping combination with reference gradients.

    Args:
        op: callable under test.
        args: positional arguments for ``op``.
        kwargs: keyword arguments for ``op``.
        output_process_fn_grad: optional callable applied to the forward
            results before differentiation.
        gradcheck_wrapper: optional callable invoked as
            ``gradcheck_wrapper(op, *args, **kwargs)`` instead of ``op``.
        assert_equal_fn: called as
            ``assert_equal_fn(actual, expected, equal_nan=True)``.

    Raises:
        RuntimeError: via ``raise_composite_compliance_error`` when the
            forward call or ``torch.autograd.grad`` raises a RuntimeError.
    """
    CCT, cct_mode = generate_cct_and_mode()

    def strip_cct(e):
        if isinstance(e, CCT):
            return e.elem
        return e

    expected = compute_expected_grads(op, args, kwargs, output_process_fn_grad, gradcheck_wrapper)

    for new_args, new_kwargs, wrapped_args, wrapped_kwargs in \
            generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
        leaves = gather_leaf_tensors(new_args, new_kwargs)
        assert len(leaves) > 0

        try:
            if gradcheck_wrapper is not None:
                outputs = gradcheck_wrapper(op, *new_args, **new_kwargs)
            else:
                outputs = op(*new_args, **new_kwargs)
            if output_process_fn_grad is not None:
                outputs = output_process_fn_grad(outputs)
        # see NOTE: [What errors are Composite Compliance trying to catch?]
        except RuntimeError as err:
            raise_composite_compliance_error(
                err,
                f"- wrapped_args: {wrapped_args}\n"
                f"- wrapped_kwargs: {wrapped_kwargs}\n"
            )

        flat_outputs = tree_flatten(outputs)[0]
        diff_outputs = [out for out in flat_outputs if out.requires_grad]
        assert len(diff_outputs) > 0

        # NB: ones, not ones_like, so we get a regular Tensor here
        plain_grads = [torch.ones(out.shape, device=out.device, dtype=out.dtype)
                       for out in diff_outputs]

        # Second level of permutations: which grad_outputs are wrapped.
        for grad_choice, wrapped_grads in generate_subclass_choices(plain_grads, CCT, cct_mode):
            try:
                actual = torch.autograd.grad(diff_outputs, leaves, grad_choice,
                                             allow_unused=True, retain_graph=True)
            # see NOTE: [What errors are Composite Compliance trying to catch?]
            except RuntimeError as err:
                raise_composite_compliance_error(
                    err,
                    f"- wrapped_args: {wrapped_args}\n"
                    f"- wrapped_kwargs: {wrapped_kwargs}\n"
                    f"- wrapped_grads: {wrapped_grads}\n"
                )

            assert_equal_fn(tuple(strip_cct(g) for g in actual), expected, equal_nan=True)
# Checks if the forward AD formula is composite compliant by testing
# all possible permutations of {primals, tangents} being
# CompositeCompliantTensor or regular Tensors.
#
# NB: it is important that op is accepted as a Callable and not an OpInfo,
# this means we can apply check_forward_ad_formula to things that aren't OpInfos
# while debugging.
def check_forward_ad_formula(op: Callable, args, kwargs, gradcheck_wrapper=None, assert_equal_fn=None):
    """Compare forward-AD results for every wrapping combination with a reference.

    Args:
        op: callable under test.
        args: positional arguments for ``op``.
        kwargs: keyword arguments for ``op``.
        gradcheck_wrapper: optional callable invoked as
            ``gradcheck_wrapper(op, *args, **kwargs)`` instead of ``op``.
        assert_equal_fn: called as
            ``assert_equal_fn(actual, expected, equal_nan=True)``, once for
            primals and once for tangents.

    Raises:
        RuntimeError: via ``raise_composite_compliance_error`` when the call
            on dual inputs raises a RuntimeError.
    """
    # View aliasing through set_ is turned off for this check.
    CCT, cct_mode = generate_cct_and_mode(autograd_view_consistency=False)

    def maybe_tangent(t):
        assert type(t) is not CCT
        # A tangent is generated only for a Tensor that requires grad, or
        # element-wise for a tensor list; everything else gets None.
        if isinstance(t, torch.Tensor) and t.requires_grad:
            return torch.randn_like(t)
        if is_tensorlist(t):
            return [torch.randn_like(e) if e.requires_grad else None for e in t]
        return None

    def strip_cct(e):
        if isinstance(e, CCT):
            return e.elem
        return e

    def run_op(call_args, call_kwargs):
        if gradcheck_wrapper is not None:
            return gradcheck_wrapper(op, *call_args, **call_kwargs)
        return op(*call_args, **call_kwargs)

    # Tangents for positional arguments are drawn first, then for kwargs.
    tangent_args = tuple(maybe_tangent(arg) for arg in args)
    flat_kwargs, spec = tree_flatten(kwargs)
    flat_tangent_kwargs = tuple(maybe_tangent(arg) for arg in flat_kwargs)
    tangent_kwargs = tree_unflatten(flat_tangent_kwargs, spec)

    with fwAD.dual_level():

        def maybe_make_dual(dual):
            # Returns a dual tensor if the primal is a tensor/tensor subclass
            # with requires_grad set; tensor lists are handled element-wise.
            primal, tangent = dual
            if isinstance(primal, torch.Tensor) and primal.requires_grad:
                return fwAD.make_dual(primal.detach(), tangent)
            if is_tensorlist(primal):
                return tuple(pri if tang is None else fwAD.make_dual(pri.detach(), tang)
                             for pri, tang in zip(primal, tangent))
            return primal

        def dualize(primal_args, primal_kwargs, tang_args, tang_kwargs):
            dual_args = tuple(maybe_make_dual(pair) for pair in zip(primal_args, tang_args))
            dual_kwargs = {key: maybe_make_dual((value, tang_kwargs[key]))
                           for key, value in primal_kwargs.items()}
            return dual_args, dual_kwargs

        # Reference run on regular Tensors.
        ref_args, ref_kwargs = dualize(args, kwargs, tangent_args, tangent_kwargs)
        expected = run_op(ref_args, ref_kwargs)
        expected = tree_map(fwAD.unpack_dual, expected)
        expected_primals = tree_map(lambda x: x.primal, expected)
        expected_tangents = tree_map(lambda x: x.tangent, expected)

        # Permutations of arg and kwargs in CCT.
        for new_args, new_kwargs, wrapped_args, wrapped_kwargs in \
                generate_subclass_choices_args_kwargs(args, kwargs, CCT, cct_mode):
            # Permutations tangent arg and tangent kwargs in CCT.
            for new_tang_args, new_tang_kwargs, wrapped_tang_args, wrapped_tang_kwargs in \
                    generate_subclass_choices_args_kwargs(tangent_args, tangent_kwargs, CCT, cct_mode):
                op_args, op_kwargs = dualize(new_args, new_kwargs, new_tang_args, new_tang_kwargs)

                try:
                    actual = run_op(op_args, op_kwargs)
                # see NOTE: [What errors are Composite Compliance trying to catch?]
                except RuntimeError as err:
                    raise_composite_compliance_error(
                        err,
                        f"- wrapped_args: {wrapped_args}\n"
                        f"- wrapped_kwargs: {wrapped_kwargs}\n"
                        f"- wrapped_tangent_args: {wrapped_tang_args}\n"
                        f"- wrapped_tangent_kwargs: {wrapped_tang_kwargs}\n"
                    )

                actual = tree_map(fwAD.unpack_dual, actual)
                actual_primals = tree_map(lambda x: strip_cct(x.primal), actual)
                actual_tangents = tree_map(lambda x: strip_cct(x.tangent), actual)
                assert_equal_fn(actual_primals, expected_primals, equal_nan=True)
                assert_equal_fn(actual_tangents, expected_tangents, equal_nan=True)
|