import math
from typing import Optional, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch import Tensor
from torch._decomp import register_decomposition
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    Number,
    NumberType,
    TensorLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import elementwise_type_promotion_wrapper, out_wrapper
from torch._refs import (
    _make_elementwise_binary_reference,
    _make_elementwise_unary_reference,
)

__all__ = [
    "bessel_j0",
    "bessel_j1",
    "entr",
    "erfcx",
    "expit",
    "i0e",
    "i1",
    "i1e",
    "log_ndtr",
    "logit",
    "log_softmax",
    "multigammaln",
    "ndtr",
    "ndtri",
    "softmax",
    "spherical_bessel_j0",
    "xlog1py",
    "zeta",
]

aten = torch._ops.ops.aten


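# Note: these are Python reference implementations of torch.special operators,
# written in terms of torch._prims / torch._refs so that most of them can be
# registered as decompositions of the corresponding ATen ops. bessel_j0 and
# bessel_j1 below are the Bessel functions of the first kind of orders 0 and 1,
# forwarded directly to their prims counterparts.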
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_j0(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_j1(a)


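# entr computes the elementwise entropy term -a * log(a): 0 where a == 0,
# -inf where a < 0, and NaN is propagated unchanged.
# For example (approximate value), entr(torch.tensor(0.5)) is about 0.3466,
# i.e. -0.5 * ln(0.5).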
@register_decomposition(aten.special_entr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def entr(a: TensorLikeType) -> TensorLikeType:
    return torch.where(
        torch.isnan(a),
        a,
        torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)),
    )


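# erfcx is the scaled complementary error function, erfcx(x) = exp(x**2) * erfc(x),
# which stays representable for large positive x where erfc(x) alone would underflow.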
@register_decomposition(aten.special_erfcx)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def erfcx(a: TensorLikeType) -> TensorLikeType:
    return prims.erfcx(a)


# alias for sigmoid
expit = torch.sigmoid


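# i1 is the modified Bessel function of the first kind of order 1; i0e and i1e
# are the exponentially scaled variants, e.g. i0e(x) = exp(-|x|) * i0(x), which
# avoid overflow for large |x|.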
@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i0e(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i0e(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i1(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i1(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i1e(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i1e(a)


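# log_ndtr computes log(ndtr(a)), the log of the standard normal CDF. Writing
# t = a / sqrt(2), ndtr(a) = erfc(-t) / 2 = exp(-t * t) * erfcx(-t) / 2, so for
# the left tail (a < 1) log ndtr(a) = log(erfcx(-t) / 2) - t * t without
# underflow; for a >= 1, ndtr(a) is close to 1 and log1p(-erfc(t) / 2) is the
# more accurate form.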
@register_decomposition(aten.special_log_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log_ndtr(a: TensorLikeType) -> TensorLikeType:
    # Note: M_SQRT1_2 is the value of 1 / √2
    M_SQRT1_2 = 0.707106781186547524400844362104849039
    t = a * M_SQRT1_2
    return torch.where(
        a < 1.0,
        torch.log(torch.special.erfcx(-t) / 2) - t * t,
        torch.log1p(-refs.erfc(t) / 2),
    )


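# logit computes log(p / (1 - p)). When eps is given, p is first clamped to
# [eps, 1 - eps]; the -1.0 sentinel below effectively disables the clamp when
# eps is None (out-of-range inputs then produce NaN from the log).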
@register_decomposition(aten.logit)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    if eps is None:
        eps = -1.0
    lo = eps
    hi = 1 - eps
    self = torch.clamp(self, lo, hi)
    return torch.log(torch.true_divide(self, torch.sub(1, self)))


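# xlog1py computes a * log1p(b) elementwise with the conventions used below:
# the result is NaN wherever b is NaN, 0 wherever a == 0 (and b is not NaN),
# and a * log1p(b) otherwise.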
@register_decomposition(aten.special_xlog1py)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    utils.check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
    if isinstance(a, TensorLike) and isinstance(b, Number):
        b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(b, TensorLike) and isinstance(a, Number):
        a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, refs.log1p(b)))
    return torch.where(torch.isnan(b), float("nan"), rhs)


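# multigammaln is the log of the multivariate gamma function:
# log Gamma_p(a) = p * (p - 1) / 4 * log(pi) + sum_{j=1..p} lgamma(a + (1 - j) / 2),
# evaluated here by broadcasting the p offsets against a trailing dimension of a.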
@register_decomposition(aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c


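# ndtr is the standard normal CDF, Phi(a) = (1 + erf(a / √2)) / 2,
# so e.g. ndtr(torch.tensor(0.0)) == 0.5.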
@register_decomposition(aten.special_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtr(a: TensorLikeType) -> TensorLikeType:
    # Note: M_SQRT1_2 is the value of 1 / √2
    M_SQRT1_2 = 0.707106781186547524400844362104849039
    a_sqrt_2 = a * M_SQRT1_2
    return (1 + torch.erf(a_sqrt_2)) * 0.5


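# ndtri is the inverse of ndtr (the standard normal quantile function),
# so e.g. ndtri(torch.tensor(0.5)) == 0.0.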
@register_decomposition(aten.special_ndtri)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtri(a: TensorLikeType) -> TensorLikeType:
    return prims.ndtri(a)


# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
    return prims.spherical_bessel_j0(a)


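# zeta computes the Hurwitz zeta function zeta(a, b) = sum_{k>=0} (k + b)**(-a),
# forwarded directly to prims.zeta.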
# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.zeta(a, b)