import math
from typing import Optional, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
from torch import Tensor
from torch._decomp import register_decomposition
from torch._prims_common import (
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    Number,
    NumberType,
    TensorLike,
    TensorLikeType,
)
from torch._prims_common.wrappers import (
    elementwise_type_promotion_wrapper,
    out_wrapper,
)
from torch._refs import (
    _make_elementwise_binary_reference,
    _make_elementwise_unary_reference,
)

__all__ = [
    "bessel_j0",
    "bessel_j1",
    "entr",
    "erfcx",
    "expit",
    "i0e",
    "i1",
    "i1e",
    "log_ndtr",
    "logit",
    "log_softmax",
    "multigammaln",
    "ndtr",
    "ndtri",
    "softmax",
    "spherical_bessel_j0",
    "xlog1py",
    "zeta",
]

aten = torch._ops.ops.aten


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def bessel_j0(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_j0(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def bessel_j1(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_j1(a)


@register_decomposition(aten.special_entr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def entr(a: TensorLikeType) -> TensorLikeType:
    return torch.where(
        torch.isnan(a),
        a,
        torch.where(a > 0, -a * torch.log(a), torch.where(a == 0, 0, -torch.inf)),
    )
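
# Note (added sketch, not part of the upstream source): entr computes the
# elementwise entropy term -a * log(a), with entr(0) == 0, entr(a) == -inf for
# a < 0, and NaN propagated unchanged. Example values follow directly from the
# branches above:
#   entr(torch.tensor([-1.0, 0.0, 0.5]))  # -> tensor([-inf, 0.0000, 0.3466])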


@register_decomposition(aten.special_erfcx)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def erfcx(a: TensorLikeType) -> TensorLikeType:
    return prims.erfcx(a)


# alias for sigmoid
expit = torch.sigmoid


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i0e(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i0e(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i1(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i1(a)


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def i1e(a: TensorLikeType) -> TensorLikeType:
    return prims.bessel_i1e(a)


@register_decomposition(aten.special_log_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def log_ndtr(a: TensorLikeType) -> TensorLikeType:
    # Note: M_SQRT1_2 is the value of 1 / √2
    M_SQRT1_2 = 0.707106781186547524400844362104849039
    t = a * M_SQRT1_2
    return torch.where(
        a < 1.0,
        torch.log(torch.special.erfcx(-t) / 2) - t * t,
        torch.log1p(-refs.erfc(t) / 2),
    )
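
# Note (added, an assumption about the intent rather than part of the original
# source): both branches compute log(ndtr(a)). The a < 1.0 branch uses the
# identity ndtr(a) = exp(-t*t) * erfcx(-t) / 2, i.e.
# log(ndtr(a)) = log(erfcx(-t) / 2) - t*t, which stays finite even when
# ndtr(a) itself underflows for very negative a; the other branch uses log1p
# so log(1 - erfc(t) / 2) remains accurate when ndtr(a) is close to 1.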


@register_decomposition(aten.logit)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("self",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def logit(self: TensorLikeType, eps: Optional[float] = None) -> TensorLikeType:
    if eps is None:
        eps = -1.0
    lo = eps
    hi = 1 - eps
    self = torch.clamp(self, lo, hi)
    return torch.log(torch.true_divide(self, torch.sub(1, self)))
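
# Note (added, not part of the original source): logit(x) = log(x / (1 - x)).
# When eps is given, inputs are first clamped to [eps, 1 - eps]; when eps is
# None, the sentinel eps = -1.0 makes the clamp a no-op for inputs in [0, 1],
# so 0 maps to -inf and 1 maps to inf.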


@register_decomposition(aten.special_xlog1py)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a", "b"),
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]):
    utils.check(
        isinstance(a, TensorLike) or isinstance(b, TensorLike),
        lambda: "Expected either argument a or b to be a Tensor",
    )

    # Operations like eq and log do not handle scalar values, so we convert
    # them to scalar_tensors.
    if isinstance(a, TensorLike) and isinstance(b, Number):
        b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
    elif isinstance(b, TensorLike) and isinstance(a, Number):
        a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)

    # mypy: expected "Tensor"
    assert isinstance(a, TensorLike)
    assert isinstance(b, TensorLike)

    rhs = torch.where(torch.eq(a, 0), 0, torch.mul(a, refs.log1p(b)))
    return torch.where(torch.isnan(b), float("nan"), rhs)
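
# Note (added sketch, not part of the original source): xlog1py(a, b) computes
# a * log1p(b) elementwise, returning 0 wherever a == 0 (even if log1p(b) is
# infinite) and NaN wherever b is NaN. Example values follow from the formula:
#   xlog1py(torch.tensor([0.0, 2.0]), torch.tensor([-1.0, 1.0]))
#   # -> tensor([0.0000, 1.3863])   # second entry is 2 * log(2)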


@register_decomposition(aten.mvlgamma)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def multigammaln(a: TensorLikeType, p: int) -> TensorLikeType:
    c = 0.25 * p * (p - 1) * math.log(math.pi)
    b = 0.5 * torch.arange(start=(1 - p), end=1, step=1, dtype=a.dtype, device=a.device)
    return torch.sum(torch.lgamma(a.unsqueeze(-1) + b), dim=-1) + c
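
# Note (added, not part of the original source): this follows the standard
# identity for the multivariate log-gamma function,
#   log Γ_p(a) = (p * (p - 1) / 4) * log(pi) + sum_{j=1}^{p} lgamma(a + (1 - j) / 2),
# with the sum realized by broadcasting `a` against the offsets stored in `b`.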


@register_decomposition(aten.special_ndtr)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtr(a: TensorLikeType) -> TensorLikeType:
    # Note: M_SQRT1_2 is the value of 1 / √2
    M_SQRT1_2 = 0.707106781186547524400844362104849039
    a_sqrt_2 = a * M_SQRT1_2
    return (1 + torch.erf(a_sqrt_2)) * 0.5
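
# Note (added, not part of the original source): ndtr is the standard normal
# CDF, Phi(a) = (1 + erf(a / sqrt(2))) / 2, so ndtr(0) == 0.5 and
# ndtr(a) -> 1 as a -> inf.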


@register_decomposition(aten.special_ndtri)
@out_wrapper()
@elementwise_type_promotion_wrapper(
    type_promoting_args=("a",),
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def ndtri(a: TensorLikeType) -> TensorLikeType:
    return prims.ndtri(a)


# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    return torch.log_softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


# Forwarding alias: the special variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def softmax(
    a: TensorLikeType,
    dim: int,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    return torch.softmax(a=a, dim=dim, dtype=dtype)  # type: ignore[call-overload]


@_make_elementwise_unary_reference(
    ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def spherical_bessel_j0(a: TensorLikeType) -> TensorLikeType:
    return prims.spherical_bessel_j0(a)


# TODO: add docstring
@_make_elementwise_binary_reference(
    type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def zeta(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    return prims.zeta(a, b)
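
# Note (added, an assumption rather than part of the original source):
# zeta(a, b) computes the Hurwitz zeta function,
#   zeta(a, b) = sum_{k=0}^{inf} 1 / (b + k)**a,
# mirroring torch.special.zeta semantics.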