123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106 |
- import torch
- import torch._prims_common as utils
- # Utilities should come BEFORE this import
- from torch._decomp import register_decomposition
- from torch._prims_common import check, TensorLikeType
- from torch._prims_common.wrappers import out_wrapper
- from torch._refs import _broadcast_shapes
- # Data conversion references.
- #
- # Note: this module breaks the usual _refs to torch naming scheme where
- # _refs.foo.bar is a ref for torch.foo.bar. The following definitions are not
- # part of _refs/__init__.py to avoid name clashes with Python builtin types
- # (like int).
# Names exported by this module: the dtype conversion refs first, then the
# remaining (non-dtype) refs.
__all__ = [
    # dtypes
    "bfloat16", "bool", "byte",
    "cdouble", "cfloat", "chalf", "char",
    "double", "float",
    "half",
    "int", "long",
    "short",
    # misc
    "complex",
]
def _make_conversion_method(name: str, dtype: torch.dtype):
    """Create a conversion ref called ``name`` that casts its input to ``dtype``.

    The returned callable accepts the tensor-like input and an optional
    ``memory_format`` (``torch.preserve_format`` by default) and forwards both
    to the input's ``.to`` method.
    """

    def fn(
        self: TensorLikeType,
        memory_format: torch.memory_format = torch.preserve_format,
    ) -> TensorLikeType:
        converted = self.to(dtype, memory_format=memory_format)  # type: ignore[call-overload]
        return converted

    # Every generated closure would otherwise report the name "fn"; give it the
    # requested public name instead.
    fn.__name__ = name
    return fn
# One conversion ref per dtype shorthand. Some of these names deliberately
# shadow Python builtins (bool, float, int) within this module.

# Boolean
bool = _make_conversion_method("bool", torch.bool)

# Integer
byte = _make_conversion_method("byte", torch.uint8)
char = _make_conversion_method("char", torch.int8)
short = _make_conversion_method("short", torch.short)
int = _make_conversion_method("int", torch.int)
long = _make_conversion_method("long", torch.long)

# Real floating point
half = _make_conversion_method("half", torch.half)
bfloat16 = _make_conversion_method("bfloat16", torch.bfloat16)
float = _make_conversion_method("float", torch.float)
double = _make_conversion_method("double", torch.double)

# Complex
chalf = _make_conversion_method("chalf", torch.complex32)
cfloat = _make_conversion_method("cfloat", torch.cfloat)
cdouble = _make_conversion_method("cdouble", torch.cdouble)
@register_decomposition(torch._ops.ops.aten.complex)
# Note: complex has type promotion tests disabled due to different semantics.
# exact_dtype is for compat with complex_check_dtype from core.
@out_wrapper(exact_dtype=True)
def complex(real: TensorLikeType, imag: TensorLikeType) -> TensorLikeType:
    """Assemble a complex tensor from its real and imaginary components.

    Both inputs must have the same dtype, and that dtype must be one of
    float16, float32 or float64; ``check`` reports a violation of either rule.
    The result is allocated from ``real`` (same layout and device) using the
    shape produced by ``_broadcast_shapes`` and the dtype produced by
    ``utils.corresponding_complex_dtype``.
    """
    supported_dtypes = (torch.float32, torch.float64, torch.float16)

    # The messages are built lazily so the f-strings are only formatted when a
    # check actually fails.
    def _unsupported_dtype_message() -> str:
        return (
            f"Expected both inputs to be Half, Float or Double tensors but got "
            f"{real.dtype} and {imag.dtype}"
        )

    def _dtype_mismatch_message() -> str:
        return (
            f"Expected object of scalar type {real.dtype} but got "
            f"scalar type {imag.dtype} for second argument"
        )

    both_supported = (
        real.dtype in supported_dtypes and imag.dtype in supported_dtypes
    )
    check(both_supported, _unsupported_dtype_message)
    check(real.dtype == imag.dtype, _dtype_mismatch_message)

    complex_dtype = utils.corresponding_complex_dtype(real.dtype)  # type: ignore[arg-type]
    broadcast_shape = _broadcast_shapes(real.shape, imag.shape)

    # pin_memory=real.is_pinned() is intentionally not forwarded: NYI.
    combined = real.new_empty(
        broadcast_shape,
        dtype=complex_dtype,
        layout=real.layout,
        device=real.device,
    )
    # Fill the two components of the freshly allocated tensor.
    combined.real = real
    combined.imag = imag
    return combined
|