123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import torch
- from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor
# Deliberately empty: ``from <this module> import *`` exports nothing.
__all__ = []  # type: ignore[var-annotated]
# Names of aten binary ops (looked up on ``torch.ops.aten``) that get an
# out-of-place MaskedTensor handler.
BINARY_NAMES = [
    "add",
    "atan2",
    "arctan2",
    "bitwise_and",
    "bitwise_or",
    "bitwise_xor",
    "bitwise_left_shift",
    "bitwise_right_shift",
    "div",
    "divide",
    "floor_divide",
    "fmod",
    "logaddexp",
    "logaddexp2",
    "mul",
    "multiply",
    "nextafter",
    "remainder",
    "sub",
    "subtract",
    "true_divide",
    "eq",
    "ne",
    "le",
    "ge",
    "greater",
    "greater_equal",
    "gt",
    "less_equal",
    "lt",
    "less",
    "maximum",
    "minimum",
    "fmax",
    "fmin",
    "not_equal",
]

# Names from BINARY_NAMES that get no in-place ("<name>_") handler --
# presumably because aten has no in-place overload for them (TODO confirm).
# "equal" is not in BINARY_NAMES, so listing it here has no effect.
_INPLACE_EXCLUDED_NAMES = frozenset(
    {"logaddexp", "logaddexp2", "equal", "fmin", "minimum", "maximum", "fmax"}
)

# In-place variants (e.g. "add_"). Filtering BINARY_NAMES, instead of taking
# a set difference, keeps the declared order. The list is then identical on
# every interpreter run, whatever the string hash seed.
INPLACE_BINARY_NAMES = [
    name + "_" for name in BINARY_NAMES if name not in _INPLACE_EXCLUDED_NAMES
]
def _get_at_least_one_mask(a, b):
    """Return the mask shared by the two binary-op operands.

    At least one of ``a``/``b`` must be a MaskedTensor. The mask is taken from
    ``a`` when ``a`` is a MaskedTensor, otherwise from ``b``.

    Raises:
        TypeError: if neither operand is a MaskedTensor.
        ValueError: if ``_masks_match`` reports that the masks differ.
    """
    lhs_is_masked = is_masked_tensor(a)
    if not (lhs_is_masked or is_masked_tensor(b)):
        raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
    if not _masks_match(a, b):
        raise ValueError("a and b must have matching masks")
    source = a if lhs_is_masked else b
    return source.get_mask()
def _binary_helper(fn, args, kwargs, inplace):
    """Apply the aten binary op ``fn`` to MaskedTensor operands.

    ``args[0]`` and ``args[1]`` are the lhs and rhs. MaskedTensors among them
    are unwrapped to their data and ``fn`` is run on the data. For sparse COO
    and CSR lhs operands, ``fn`` runs on the ``values()`` and the result is
    rebuilt with the lhs indices and size.

    Args:
        fn: aten op applied to the unwrapped data (e.g. ``torch.ops.aten.add``).
        args: positional arguments of the original call. Entries after the rhs
            must not be tensors.
        kwargs: keyword arguments of the original call; must be empty.
        inplace: if True, store the result (with the lhs mask) into ``args[0]``
            and return ``args[0]``.

    Returns:
        ``args[0]`` when ``inplace`` is True. Otherwise the result data paired
        with the shared input mask via ``_wrap_result``.

    Raises:
        ValueError: if kwargs are given, the input masks do not match, or two
            sparse operands of the same layout differ in indices or size.
        TypeError: if an extra positional argument is a tensor, or (out of
            place) if neither lhs nor rhs is a MaskedTensor.
    """
    if kwargs:
        raise ValueError("len(kwargs) must equal 0")
    for a in args[2:]:
        if torch.is_tensor(a):
            raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs")
    if not _masks_match(*args[:2]):
        raise ValueError(
            "Input masks must match. If you need support for this, please open an issue on Github."
        )

    # The mapped kwargs are always empty here (checked above), so discard them.
    data_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_data())
    mask_args, _ = _map_mt_args_kwargs(args, kwargs, lambda x: x.get_mask())

    args0_layout = data_args[0].layout
    # The rhs may be a Python scalar, which has no layout.
    same_layout = (
        (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1]))
        and args0_layout == data_args[1].layout
    )

    if args0_layout == torch.sparse_coo:
        if same_layout:
            if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
                raise ValueError(
                    "sparse_coo indices must match. If you need support for this, please open an issue on Github."
                )
            if data_args[0].size() != data_args[1].size():
                raise ValueError("input1 and input2 must have the same size for binary functions.")
            data_args[1] = data_args[1].values()
        i = data_args[0].indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        result_data = torch.sparse_coo_tensor(i, v, size)
    elif args0_layout == torch.sparse_csr:
        if same_layout:
            if not (
                _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
                and _tensors_match(
                    data_args[0].col_indices(), data_args[1].col_indices()
                )
            ):
                raise ValueError(
                    "sparse_csr indices must match. If you need support for this, please open an issue on Github."
                )
            # Same guard as the COO branch: matching indices do not imply
            # matching column counts.
            if data_args[0].size() != data_args[1].size():
                raise ValueError("input1 and input2 must have the same size for binary functions.")
            data_args[1] = data_args[1].values()
        crow = data_args[0].crow_indices()
        col = data_args[0].col_indices()
        size = data_args[0].size()
        data_args[0] = data_args[0].values()
        v = fn(*data_args)
        # Pass the size explicitly. Without it, sparse_csr_tensor infers the
        # smallest size that holds the indices, which drops trailing
        # all-empty columns.
        result_data = torch.sparse_csr_tensor(crow, col, v, size)
    else:
        result_data = fn(*data_args)

    if inplace:
        args[0]._set_data_mask(result_data, mask_args[0])
        return args[0]

    result_mask = _get_at_least_one_mask(*args[:2])
    # Sparse tensors don't have strides, so we can only expand if the layout is strided.
    if args0_layout == torch.strided:
        result_mask = result_mask.expand_as(result_data)
    return _wrap_result(result_data, result_mask)
- def _torch_binary(fn_name):
- fn = getattr(torch.ops.aten, fn_name)
- def binary_fn(*args, **kwargs):
- return _binary_helper(fn, args, kwargs, inplace=False)
- return binary_fn
- def _torch_inplace_binary(fn_name):
- fn = getattr(torch.ops.aten, fn_name)
- def binary_fn(*args, **kwargs):
- return _binary_helper(fn, args, kwargs, inplace=True)
- return binary_fn
# Dispatch tables keyed by the aten op object (``torch.ops.aten.<name>``);
# each value is the MaskedTensor handler built for that op.
NATIVE_BINARY_MAP = dict(
    (getattr(torch.ops.aten, op_name), _torch_binary(op_name))
    for op_name in BINARY_NAMES
)
NATIVE_INPLACE_BINARY_MAP = dict(
    (getattr(torch.ops.aten, op_name), _torch_inplace_binary(op_name))
    for op_name in INPLACE_BINARY_NAMES
)

# The same ops as plain lists, in table insertion order.
NATIVE_BINARY_FNS = [*NATIVE_BINARY_MAP]
NATIVE_INPLACE_BINARY_FNS = [*NATIVE_INPLACE_BINARY_MAP]
def _is_native_binary(fn):
    """Return True if ``fn`` has an out-of-place or in-place MaskedTensor handler.

    Args:
        fn: the op being dispatched (an aten op object for matches).
    """
    # The dispatch dicts hold exactly the keys listed in NATIVE_*_FNS.
    # Checking the dicts is a hash lookup instead of a linear list scan.
    return fn in NATIVE_BINARY_MAP or fn in NATIVE_INPLACE_BINARY_MAP
def _apply_native_binary(fn, *args, **kwargs):
    """Run the MaskedTensor handler registered for ``fn``.

    The out-of-place table is consulted before the in-place table.

    Args:
        fn: the op being dispatched.
        *args, **kwargs: forwarded unchanged to the handler.

    Returns:
        The handler's result, or ``NotImplemented`` if ``fn`` has no handler.
    """
    # One hash lookup per table, instead of a list scan followed by a dict
    # index. Handlers are never None, so None means "not registered".
    handler = NATIVE_BINARY_MAP.get(fn)
    if handler is None:
        handler = NATIVE_INPLACE_BINARY_MAP.get(fn)
    if handler is None:
        return NotImplemented
    return handler(*args, **kwargs)
|