binary.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import torch
  3. from .core import _map_mt_args_kwargs, _masks_match, _tensors_match, _wrap_result, is_masked_tensor
  4. __all__ = [] # type: ignore[var-annotated]
  5. BINARY_NAMES = [
  6. "add",
  7. "atan2",
  8. "arctan2",
  9. "bitwise_and",
  10. "bitwise_or",
  11. "bitwise_xor",
  12. "bitwise_left_shift",
  13. "bitwise_right_shift",
  14. "div",
  15. "divide",
  16. "floor_divide",
  17. "fmod",
  18. "logaddexp",
  19. "logaddexp2",
  20. "mul",
  21. "multiply",
  22. "nextafter",
  23. "remainder",
  24. "sub",
  25. "subtract",
  26. "true_divide",
  27. "eq",
  28. "ne",
  29. "le",
  30. "ge",
  31. "greater",
  32. "greater_equal",
  33. "gt",
  34. "less_equal",
  35. "lt",
  36. "less",
  37. "maximum",
  38. "minimum",
  39. "fmax",
  40. "fmin",
  41. "not_equal",
  42. ]
  43. INPLACE_BINARY_NAMES = [
  44. n + "_"
  45. for n in (
  46. list(
  47. set(BINARY_NAMES)
  48. - {
  49. "logaddexp",
  50. "logaddexp2",
  51. "equal",
  52. "fmin",
  53. "minimum",
  54. "maximum",
  55. "fmax",
  56. }
  57. )
  58. )
  59. ]
  60. def _get_at_least_one_mask(a, b):
  61. if not is_masked_tensor(a) and not is_masked_tensor(b):
  62. raise TypeError("At least one of `a` and `b` must be a MaskedTensor")
  63. if not _masks_match(a, b):
  64. raise ValueError("a and b must have matching masks")
  65. if is_masked_tensor(a):
  66. return a.get_mask()
  67. return b.get_mask()
  68. def _binary_helper(fn, args, kwargs, inplace):
  69. if len(kwargs) != 0:
  70. raise ValueError("len(kwargs) must equal 0")
  71. for a in args[2:]:
  72. if torch.is_tensor(a):
  73. raise TypeError("MaskedTensor binary ops do not support Tensor arguments aside from the lhs and rhs")
  74. if not _masks_match(*args[:2]):
  75. raise ValueError(
  76. "Input masks must match. If you need support for this, please open an issue on Github."
  77. )
  78. data_args, data_kwargs = _map_mt_args_kwargs(
  79. args, kwargs, lambda x: x.get_data()
  80. )
  81. mask_args, mask_kwargs = _map_mt_args_kwargs(
  82. args, kwargs, lambda x: x.get_mask()
  83. )
  84. args0_layout = data_args[0].layout
  85. same_layout = (
  86. (torch.is_tensor(data_args[1]) or is_masked_tensor(data_args[1])) and
  87. (args0_layout == data_args[1].layout)
  88. )
  89. if args0_layout == torch.sparse_coo:
  90. if same_layout:
  91. if not _tensors_match(data_args[0].indices(), data_args[1].indices()):
  92. raise ValueError(
  93. "sparse_coo indices must match. If you need support for this, please open an issue on Github."
  94. )
  95. if data_args[0].size() != data_args[1].size():
  96. raise ValueError("input1 and input2 must have the same size for binary functions.")
  97. data_args[1] = data_args[1].values()
  98. i = data_args[0].indices()
  99. size = data_args[0].size()
  100. data_args[0] = data_args[0].values()
  101. v = fn(*data_args)
  102. result_data = torch.sparse_coo_tensor(i, v, size)
  103. elif args0_layout == torch.sparse_csr:
  104. if same_layout:
  105. if not (
  106. _tensors_match(data_args[0].crow_indices(), data_args[1].crow_indices())
  107. and _tensors_match(
  108. data_args[0].col_indices(), data_args[1].col_indices()
  109. )
  110. ):
  111. raise ValueError(
  112. "sparse_csr indices must match. If you need support for this, please open an issue on Github."
  113. )
  114. data_args[1] = data_args[1].values()
  115. crow = data_args[0].crow_indices()
  116. col = data_args[0].col_indices()
  117. data_args[0] = data_args[0].values()
  118. v = fn(*data_args)
  119. result_data = torch.sparse_csr_tensor(crow, col, v)
  120. else:
  121. result_data = fn(*data_args)
  122. if inplace:
  123. args[0]._set_data_mask(result_data, mask_args[0])
  124. return args[0]
  125. else:
  126. result_mask = _get_at_least_one_mask(*args[:2])
  127. # sparse tensors don't have strides so we can only expand if the layout is strided
  128. if args0_layout == torch.strided:
  129. result_mask = result_mask.expand_as(result_data)
  130. return _wrap_result(result_data, result_mask)
  131. def _torch_binary(fn_name):
  132. fn = getattr(torch.ops.aten, fn_name)
  133. def binary_fn(*args, **kwargs):
  134. return _binary_helper(fn, args, kwargs, inplace=False)
  135. return binary_fn
  136. def _torch_inplace_binary(fn_name):
  137. fn = getattr(torch.ops.aten, fn_name)
  138. def binary_fn(*args, **kwargs):
  139. return _binary_helper(fn, args, kwargs, inplace=True)
  140. return binary_fn
  141. NATIVE_BINARY_MAP = {
  142. getattr(torch.ops.aten, name): _torch_binary(name) for name in BINARY_NAMES
  143. }
  144. NATIVE_INPLACE_BINARY_MAP = {
  145. getattr(torch.ops.aten, name): _torch_inplace_binary(name)
  146. for name in INPLACE_BINARY_NAMES
  147. }
  148. NATIVE_BINARY_FNS = list(NATIVE_BINARY_MAP.keys())
  149. NATIVE_INPLACE_BINARY_FNS = list(NATIVE_INPLACE_BINARY_MAP.keys())
  150. def _is_native_binary(fn):
  151. return fn in NATIVE_BINARY_FNS or fn in NATIVE_INPLACE_BINARY_FNS
  152. def _apply_native_binary(fn, *args, **kwargs):
  153. if fn in NATIVE_BINARY_FNS:
  154. return NATIVE_BINARY_MAP[fn](*args, **kwargs)
  155. if fn in NATIVE_INPLACE_BINARY_FNS:
  156. return NATIVE_INPLACE_BINARY_MAP[fn](*args, **kwargs)
  157. return NotImplemented