unary.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. import torch
  3. from .core import _map_mt_args_kwargs, _wrap_result
  4. __all__ = [] # type: ignore[var-annotated]
  5. UNARY_NAMES = [
  6. "abs",
  7. "absolute",
  8. "acos",
  9. "arccos",
  10. "acosh",
  11. "arccosh",
  12. "angle",
  13. "asin",
  14. "arcsin",
  15. "asinh",
  16. "arcsinh",
  17. "atan",
  18. "arctan",
  19. "atanh",
  20. "arctanh",
  21. "bitwise_not",
  22. "ceil",
  23. "clamp",
  24. "clip",
  25. "conj_physical",
  26. "cos",
  27. "cosh",
  28. "deg2rad",
  29. "digamma",
  30. "erf",
  31. "erfc",
  32. "erfinv",
  33. "exp",
  34. "exp2",
  35. "expm1",
  36. "fix",
  37. "floor",
  38. "frac",
  39. "lgamma",
  40. "log",
  41. "log10",
  42. "log1p",
  43. "log2",
  44. "logit",
  45. "i0",
  46. "isnan",
  47. "nan_to_num",
  48. "neg",
  49. "negative",
  50. "positive",
  51. "pow",
  52. "rad2deg",
  53. "reciprocal",
  54. "round",
  55. "rsqrt",
  56. "sigmoid",
  57. "sign",
  58. "sgn",
  59. "signbit",
  60. "sin",
  61. "sinc",
  62. "sinh",
  63. "sqrt",
  64. "square",
  65. "tan",
  66. "tanh",
  67. "trunc",
  68. ]
  69. INPLACE_UNARY_NAMES = [
  70. n + "_"
  71. for n in (list(set(UNARY_NAMES) - {"angle", "positive", "signbit", "isnan"}))
  72. ]
  73. # Explicitly tracking functions we know are currently not supported
  74. # This might be due to missing code gen or because of complex semantics
  75. UNARY_NAMES_UNSUPPORTED = [
  76. "atan2",
  77. "arctan2",
  78. "bitwise_left_shift",
  79. "bitwise_right_shift",
  80. "copysign",
  81. "float_power",
  82. "fmod",
  83. "frexp",
  84. "gradient",
  85. "imag",
  86. "ldexp",
  87. "lerp",
  88. "logical_not",
  89. "hypot",
  90. "igamma",
  91. "igammac",
  92. "mvlgamma",
  93. "nextafter",
  94. "polygamma",
  95. "real",
  96. "remainder",
  97. "true_divide",
  98. "xlogy",
  99. ]
  100. def _unary_helper(fn, args, kwargs, inplace):
  101. if len(kwargs) != 0:
  102. raise ValueError("MaskedTensor unary ops require that len(kwargs) == 0. "
  103. "If you need support for this, please open an issue on Github.")
  104. for a in args[1:]:
  105. if torch.is_tensor(a):
  106. raise TypeError("MaskedTensor unary ops do not support additional Tensor arguments")
  107. mask_args, mask_kwargs = _map_mt_args_kwargs(
  108. args, kwargs, lambda x: x._masked_mask
  109. )
  110. data_args, data_kwargs = _map_mt_args_kwargs(
  111. args, kwargs, lambda x: x._masked_data
  112. )
  113. if args[0].layout == torch.sparse_coo:
  114. data_args[0] = data_args[0].coalesce()
  115. s = data_args[0].size()
  116. i = data_args[0].indices()
  117. data_args[0] = data_args[0].coalesce().values()
  118. v = fn(*data_args)
  119. result_data = torch.sparse_coo_tensor(i, v, size=s)
  120. elif args[0].layout == torch.sparse_csr:
  121. crow = data_args[0].crow_indices()
  122. col = data_args[0].col_indices()
  123. data_args[0] = data_args[0].values()
  124. v = fn(*data_args)
  125. result_data = torch.sparse_csr_tensor(crow, col, v)
  126. else:
  127. result_data = fn(*data_args)
  128. if inplace:
  129. args[0]._set_data_mask(result_data, mask_args[0])
  130. return args[0]
  131. else:
  132. return _wrap_result(result_data, mask_args[0])
  133. def _torch_unary(fn_name):
  134. fn = getattr(torch.ops.aten, fn_name)
  135. def unary_fn(*args, **kwargs):
  136. return _unary_helper(fn, args, kwargs, inplace=False)
  137. return unary_fn
  138. def _torch_inplace_unary(fn_name):
  139. fn = getattr(torch.ops.aten, fn_name)
  140. def unary_fn(*args, **kwargs):
  141. return _unary_helper(fn, args, kwargs, inplace=True)
  142. return unary_fn
  143. NATIVE_UNARY_MAP = {
  144. getattr(torch.ops.aten, name): _torch_unary(name) for name in UNARY_NAMES
  145. }
  146. NATIVE_INPLACE_UNARY_MAP = {
  147. getattr(torch.ops.aten, name): _torch_inplace_unary(name)
  148. for name in INPLACE_UNARY_NAMES
  149. }
  150. NATIVE_UNARY_FNS = list(NATIVE_UNARY_MAP.keys())
  151. NATIVE_INPLACE_UNARY_FNS = list(NATIVE_INPLACE_UNARY_MAP.keys())
  152. def _is_native_unary(fn):
  153. return fn in NATIVE_UNARY_FNS or fn in NATIVE_INPLACE_UNARY_FNS
  154. def _apply_native_unary(fn, *args, **kwargs):
  155. if fn in NATIVE_UNARY_FNS:
  156. return NATIVE_UNARY_MAP[fn](*args, **kwargs)
  157. if fn in NATIVE_INPLACE_UNARY_FNS:
  158. return NATIVE_INPLACE_UNARY_MAP[fn](*args, **kwargs)
  159. return NotImplemented