# Copyright (c) Meta Platforms, Inc. and affiliates
from functools import partial

import torch

from .binary import (
    _apply_native_binary,
    NATIVE_BINARY_FNS,
    NATIVE_INPLACE_BINARY_FNS,
)
from .core import is_masked_tensor, MaskedTensor, _get_data, _masks_match, _maybe_get_mask
from .passthrough import (
    _apply_pass_through_fn,
    PASSTHROUGH_FNS
)
from .reductions import (
    _apply_reduction,
    NATIVE_REDUCE_FNS,
    TORCH_REDUCE_FNS,
    TENSOR_REDUCE_FNS,
)
from .unary import (
    _apply_native_unary,
    NATIVE_UNARY_FNS,
    NATIVE_INPLACE_UNARY_FNS,
)

__all__ = []  # type: ignore[var-annotated]


def _check_args_kwargs_length(args, kwargs, error_prefix, len_args=None, len_kwargs=None):
    if len_args is not None and len_args != len(args):
        raise ValueError(f"{error_prefix}: len(args) must be {len_args} but got {len(args)}")
    if len_kwargs is not None and len_kwargs != len(kwargs):
        raise ValueError(f"{error_prefix}: len(kwargs) must be {len_kwargs} but got {len(kwargs)}")


class _MaskedContiguous(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")

        if input.is_contiguous():
            return input

        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.contiguous(), mask.contiguous())

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _MaskedToDense(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")

        if input.layout == torch.strided:
            return input

        ctx.layout = input.layout
        data = input.get_data()
        mask = input.get_mask()

        return MaskedTensor(data.to_dense(), mask.to_dense())

    @staticmethod
    def backward(ctx, grad_output):
        layout = ctx.layout
        if layout == torch.sparse_coo:
            return grad_output.to_sparse_coo()
        elif layout == torch.sparse_csr:
            return grad_output.to_sparse_csr()
        elif layout == torch.strided:
            return grad_output.to_dense()
        raise ValueError("to_dense: Unsupported input layout: ", layout)


class _MaskedToSparse(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")

        # Following the convention from sparse tensors that to_sparse always means that we convert to sparse_coo
        if input.layout == torch.sparse_coo:
            return input

        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_coo().coalesce()
        sparse_data = data.sparse_mask(sparse_mask)

        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()


class _MaskedToSparseCsr(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        if not is_masked_tensor(input):
            raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")

        if input._masked_data.ndim != 2:
            raise ValueError(f"Only 2D tensors can be converted to the SparseCsr layout but got shape: {input._masked_data.size()}")

        if input.layout == torch.sparse_csr:
            return input

        data = input.get_data()
        mask = input.get_mask()
        sparse_mask = mask.to_sparse_csr()
        sparse_data = data.sparse_mask(sparse_mask)

        return MaskedTensor(sparse_data, sparse_mask)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.to_dense()

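
# Illustrative sketch (not part of the original module): the autograd Functions
# above back MaskedTensor's layout conversions, so a round trip like the one
# below (with `mt` a hypothetical dense 2D MaskedTensor) stays differentiable
# through the conversion:
#
#     mt_csr = mt.to_sparse_csr()   # routed through _MaskedToSparseCsr.apply
#     mt_dense = mt_csr.to_dense()  # routed through _MaskedToDense.apply
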

class _MaskedWhere(torch.autograd.Function):
    @staticmethod
    def forward(ctx, cond, self, other):
        ctx.mark_non_differentiable(cond)
        ctx.save_for_backward(cond)
        return torch.ops.aten.where(cond, self, other)

    @staticmethod
    def backward(ctx, grad_output):
        (cond,) = ctx.saved_tensors

        def masked_out_like(mt):
            return MaskedTensor(mt.get_data(), torch.zeros_like(mt.get_mask()).bool())

        return (
            None,
            torch.ops.aten.where(cond, grad_output, masked_out_like(grad_output)),
            torch.ops.aten.where(cond, masked_out_like(grad_output), grad_output),
        )


_MASKEDTENSOR_FUNCTION_TABLE = {}

_function_fn_apply_map = {
    (tuple(NATIVE_REDUCE_FNS), tuple(TORCH_REDUCE_FNS), tuple(TENSOR_REDUCE_FNS)): _apply_reduction,
}

for fn_map_list, apply_fn in _function_fn_apply_map.items():
    for fn_map in fn_map_list:
        for fn in fn_map:
            _MASKEDTENSOR_FUNCTION_TABLE[fn] = partial(apply_fn, fn)


def register_function_func(ops):
    """
    Used for registering a new __torch_function__ function to MaskedTensor
    Called via _MASKEDTENSOR_FUNCTION_TABLE[func](*args, **kwargs)

    The code to register a new function looks like:

    @register_function_func(list_of_ops)
    def foo(func, *args, **kwargs):
    """
    def wrapper(func):
        for op in ops:
            _MASKEDTENSOR_FUNCTION_TABLE[op] = partial(func, op)
    return wrapper


@register_function_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_function_reductions(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_function_func([torch.Tensor.where, torch.where])
def _function_where(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, "__torch_function__, torch.where", len_args=3, len_kwargs=0)
    return _MaskedWhere.apply(*args)


@register_function_func([torch.Tensor.contiguous])
def _function_contiguous(func, *args, **kwargs):
    return _MaskedContiguous.apply(args[0])


@register_function_func([torch.Tensor.to_dense])
def _function_to_dense(func, *args, **kwargs):
    return _MaskedToDense.apply(args[0])


@register_function_func([torch.Tensor.to_sparse])
def _function_to_sparse(func, *args, **kwargs):
    return _MaskedToSparse.apply(args[0])


@register_function_func([torch.Tensor.to_sparse_csr])
def _function_to_sparse_csr(func, *args, **kwargs):
    return _MaskedToSparseCsr.apply(args[0])

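
# Illustrative sketch (not part of the original module): registering an extra
# __torch_function__ handler follows the pattern described in
# register_function_func's docstring. The op choice and handler name below are
# hypothetical and only demonstrate the registration mechanics:
#
#     @register_function_func([torch.Tensor.flatten])
#     def _function_flatten(func, *args, **kwargs):
#         data = _get_data(args[0])
#         mask = _maybe_get_mask(args[0])
#         return MaskedTensor(func(data), func(mask))
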

_MASKEDTENSOR_DISPATCH_TABLE = {}


def register_dispatch_func(aten_ops):
    """
    Used for registering a new __torch_dispatch__ function to MaskedTensor
    Called via _MASKEDTENSOR_DISPATCH_TABLE[func](*args, **kwargs)

    The code to register a new function looks like:

    @register_dispatch_func(list_of_ops)
    def foo(func, *args, **kwargs):
    """
    def wrapper(func):
        for aten_op in aten_ops:
            _MASKEDTENSOR_DISPATCH_TABLE[aten_op] = partial(func, aten_op)
    return wrapper


@register_dispatch_func(NATIVE_REDUCE_FNS + TORCH_REDUCE_FNS + TENSOR_REDUCE_FNS)
def _general_reduction(func, *args, **kwargs):
    return _apply_reduction(func, *args, **kwargs)


@register_dispatch_func(PASSTHROUGH_FNS)
def _general_passthrough(func, *args, **kwargs):
    return _apply_pass_through_fn(func, *args, **kwargs)


@register_dispatch_func(NATIVE_UNARY_FNS + NATIVE_INPLACE_UNARY_FNS)
def _general_unary(func, *args, **kwargs):
    return _apply_native_unary(func, *args, **kwargs)


@register_dispatch_func(NATIVE_BINARY_FNS + NATIVE_INPLACE_BINARY_FNS)
def _general_binary(func, *args, **kwargs):
    return _apply_native_binary(func, *args, **kwargs)


@register_dispatch_func([torch.ops.aten.stride])
def stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.aten.sym_stride])
def sym_stride(func, *args, **kwargs):
    return None


@register_dispatch_func([torch.ops.prim.layout])
def layout(func, *args, **kwargs):
    return _get_data(args[0]).layout


@register_dispatch_func([torch.ops.aten.is_contiguous])
def is_contiguous(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_contiguous"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_strides_like_format])
def is_strides_like_format(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_strides_like_format"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.is_non_overlapping_and_dense])
def is_non_overlapping_and_dense(func, *args, **kwargs):
    data = _get_data(args[0])
    if data.is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have is_non_overlapping_and_dense"
        )
    return func(data, *args[1:], **kwargs)


@register_dispatch_func([torch.ops.aten.contiguous])
def contiguous(func, *args, **kwargs):
    if _get_data(args[0]).is_sparse:
        raise ValueError(
            "MaskedTensors with sparse data do not have contiguous"
        )
    return _MaskedContiguous.apply(args[0])


@register_dispatch_func([torch.ops.aten.new_empty_strided])
def new_empty_strided(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    if tuple(args[1]) != tuple(data.size()):
        raise ValueError(f"__torch_dispatch__, {func}: args[1] expected to be the same as data.size()")
    if tuple(args[2]) != tuple(data.stride()):
        raise ValueError(f"__torch_dispatch__, {func}: args[2] expected to be the same as data.stride()")
    return MaskedTensor(func(data, args[1], args[2], **kwargs), mask)


@register_dispatch_func([torch.ops.aten._local_scalar_dense])
def _local_scalar_dense(func, *args, **kwargs):
    if not _maybe_get_mask(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected a mask tensor")
    return torch.ops.aten._local_scalar_dense(_get_data(args[0]))


@register_dispatch_func([torch.ops.aten.detach, torch.ops.aten.clone])
def _apply_fn_on_data(func, *args, **kwargs):
    return MaskedTensor(func(_get_data(args[0])), _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._to_copy])
def _to_copy(func, *args, **kwargs):
    new_data = func(_get_data(args[0]), *args[1:], **kwargs)
    return MaskedTensor(new_data, _maybe_get_mask(args[0]))


@register_dispatch_func([torch.ops.aten._softmax])
def _softmax(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
    data = _get_data(args[0])
    mask = _maybe_get_mask(args[0])
    result_data = torch.ops.aten._masked_softmax(data, ~mask, args[1], 2)
    return MaskedTensor(result_data, mask)


@register_dispatch_func([torch.ops.aten.ones_like])
def ones_like(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
    result_data = func(_get_data(args[0]), **kwargs)
    return MaskedTensor(result_data, _maybe_get_mask(args[0]))

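
# Illustrative sketch (not part of the original module): registering an extra
# __torch_dispatch__ handler mirrors the ones_like handler above. The aten op
# below exists, but the handler name and the decision to build the result from
# the data/mask pair are a hypothetical example of the pattern, not upstream
# behavior:
#
#     @register_dispatch_func([torch.ops.aten.zeros_like])
#     def zeros_like(func, *args, **kwargs):
#         _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1)
#         result_data = func(_get_data(args[0]), **kwargs)
#         return MaskedTensor(result_data, _maybe_get_mask(args[0]))
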

@register_dispatch_func([torch.ops.aten._softmax_backward_data])
def _softmax_backward_data(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=4)
    grad, output, dim, input_dtype = args
    if is_masked_tensor(grad) and is_masked_tensor(output):
        if not _masks_match(grad, output):
            raise ValueError(f"__torch_dispatch__, {func}: expected the masks of grad and output to match")
        grad_data = _get_data(grad)
        new_grad_data = torch.ops.aten._masked_softmax_backward(
            grad_data,
            _get_data(output),
            ~_maybe_get_mask(grad),
            dim % grad_data.ndim,
        )
        res = MaskedTensor(new_grad_data, _maybe_get_mask(grad))
        return res
    else:
        raise ValueError(f"__torch_dispatch__, {func}: grad and output must both be MaskedTensors")


@register_dispatch_func([torch.ops.aten.copy_])
def copy_(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    if not _masks_match(_maybe_get_mask(args[0]), _maybe_get_mask(args[1])):
        raise ValueError("args[0] mask and args[1] mask must match but do not")
    func(_get_data(args[0]), _get_data(args[1]))
    return args[0]


@register_dispatch_func([torch.ops.aten.where])
def where(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=3, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mx = args[1]
    my = args[2]
    if not is_masked_tensor(mx):
        mx = MaskedTensor(mx, torch.ones_like(mx, dtype=torch.bool))
    if not is_masked_tensor(my):
        my = MaskedTensor(my, torch.ones_like(my, dtype=torch.bool))
    new_data = func(args[0], mx.get_data(), my.get_data())
    new_mask = func(args[0], mx.get_mask(), my.get_mask())
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.to_sparse])
def to_sparse(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise TypeError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt, dtype=torch.bool))
    if mt.is_sparse_coo():
        return mt
    new_mask = func(_maybe_get_mask(args[0])).coalesce()
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.to_sparse_csr])
def to_sparse_csr(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    if mt.is_sparse_csr():
        return mt
    new_mask = func(_maybe_get_mask(args[0]))
    new_data = _get_data(args[0]).sparse_mask(new_mask)
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._to_dense])
def _to_dense(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    if not torch.is_tensor(args[0]):
        raise ValueError(f"__torch_dispatch__, {func}: expected args[0] to be a tensor")
    mt = args[0]
    if not is_masked_tensor(mt):
        mt = MaskedTensor(mt, torch.ones_like(mt).bool())
    new_data = func(_get_data(args[0]))
    new_mask = func(_maybe_get_mask(args[0]))
    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten._indices])
def _indices(func, *args, **kwargs):
    # Assumes data is sparse
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    data = _get_data(args[0]).indices()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._values])
def _values(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=1, len_kwargs=0)
    data = _get_data(args[0]).values()
    return MaskedTensor(data, torch.ones_like(data).bool())


@register_dispatch_func([torch.ops.aten._sparse_coo_tensor_with_dims_and_tensors])
def _sparse_coo_tensor_with_dims_and_tensors(func, *args, **kwargs):
    new_args = list(args)
    if is_masked_tensor(args[-1]):
        new_args[-1] = args[-1].get_data()
    if is_masked_tensor(args[-2]):
        new_args[-2] = args[-2].get_data()

    new_data = func(*new_args, **kwargs)
    new_args[-1] = torch.ones_like(new_args[-1])
    new_mask = func(*new_args, **kwargs).bool()

    return MaskedTensor(new_data, new_mask)


@register_dispatch_func([torch.ops.aten.is_same_size])
def is_same_size(func, *args, **kwargs):
    _check_args_kwargs_length(args, kwargs, f"__torch_dispatch__, {func}", len_args=2)
    return _get_data(args[0]).is_same_size(_get_data(args[1]))
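
# Illustrative sketch (not part of the original module): once the tables above
# are populated, MaskedTensor's __torch_function__/__torch_dispatch__ machinery
# looks up the incoming op and calls the registered handler, e.g. (hypothetical
# values):
#
#     data = torch.tensor([1.0, 2.0, 3.0])
#     mask = torch.tensor([True, False, True])
#     mt = MaskedTensor(data, mask)
#     torch.where(mask, mt, mt)  # handled by _function_where -> _MaskedWhere.apply
#     mt.to_sparse()             # handled by _function_to_sparse -> _MaskedToSparse.apply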