# Copyright (c) Meta Platforms, Inc. and affiliates
"""
Ops that are applied independently to the data and the mask of a MaskedTensor.

Take ``select`` or ``cat`` as an example: such an operation can be run on the
data and on the mask of a MaskedTensor separately, and the two results are
then wrapped together into a new MaskedTensor.
"""

import torch

from .core import _map_mt_args_kwargs, _wrap_result


__all__: list[str] = []


# aten ops that are forwarded unchanged to both the data and the mask.
PASSTHROUGH_FNS = [
    torch.ops.aten.select,
    torch.ops.aten.transpose,
    torch.ops.aten.split,
    torch.ops.aten.t,
    torch.ops.aten.slice,
    torch.ops.aten.slice_backward,
    torch.ops.aten.select_backward,
    torch.ops.aten.index,
    torch.ops.aten.expand,
    torch.ops.aten.view,
    torch.ops.aten._unsafe_view,
    torch.ops.aten._reshape_alias,
    torch.ops.aten.cat,
    torch.ops.aten.unsqueeze,
]


def _is_pass_through_fn(fn):
    """Return True if ``fn`` is one of the ops listed in ``PASSTHROUGH_FNS``.

    The check is a membership test against the current contents of the list.
    Entries appended to ``PASSTHROUGH_FNS`` at runtime are therefore honoured.
    """
    return fn in PASSTHROUGH_FNS


def _apply_pass_through_fn(fn, *args, **kwargs):
    """Apply ``fn`` separately to the data and the mask, then re-wrap them.

    ``args``/``kwargs`` are mapped twice through ``_map_mt_args_kwargs``.
    The first pass uses ``get_data()`` on each mapped argument and the second
    uses ``get_mask()``. ``fn`` is invoked once per pass.

    NOTE(review): this presumably leaves non-MaskedTensor arguments (dims,
    indices, sizes) untouched, so both calls receive the same scalar
    arguments. Confirm against ``_map_mt_args_kwargs`` in ``core``.

    Args:
        fn: the op to apply, expected to be one of ``PASSTHROUGH_FNS``.
        *args, **kwargs: the op's original arguments, possibly containing
            MaskedTensors.

    Returns:
        The result of ``_wrap_result(result_data, result_mask)``, i.e. the
        data and mask results combined back into MaskedTensor form.
    """
    data_args, data_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    result_data = fn(*data_args, **data_kwargs)

    # Run the identical op on the masks so the mask stays aligned with the data
    # (same selection / view / concatenation).
    mask_args, mask_kwargs = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())
    result_mask = fn(*mask_args, **mask_kwargs)

    return _wrap_result(result_data, result_mask)