from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch._prims as prims
import torch._prims_common as utils
import torch._refs as refs
import torch._refs.linalg as linalg
from torch import Tensor
from torch._prims_common import (
    check,
    check_fp_or_complex,
    check_is_matrix,
    Dim,
    DimsType,
    NumberType,
    TensorLikeType,
)
from torch._prims_common.wrappers import _maybe_convert_to_dtype, out_wrapper

__all__ = [
    "svd",
    "svdvals",
    "vector_norm",
    "matrix_norm",
    "norm",
]


def check_norm_dtype(dtype: Optional[torch.dtype], x_dtype: torch.dtype, fn_name: str):
    """
    Checks related to the dtype kwarg in `linalg.*norm` functions
    """
    if dtype is not None:
        check(
            utils.is_float_dtype(dtype) or utils.is_complex_dtype(dtype),
            lambda: f"{fn_name}: dtype should be floating point or complex. Got {dtype}",
        )
        check(
            utils.is_complex_dtype(dtype) == utils.is_complex_dtype(x_dtype),
            lambda: "{fn_name}: dtype should be {d} for {d} inputs. Got {dtype}".format(
                fn_name=fn_name,
                d="complex" if utils.is_complex_dtype(x_dtype) else "real",
                dtype=dtype,
            ),
        )
        check(
            utils.get_higher_dtype(dtype, x_dtype) == dtype,
            lambda: f"{fn_name}: the dtype of the input ({x_dtype}) should be convertible "
            f"without narrowing to the specified dtype ({dtype})",
        )
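
# A minimal usage sketch of check_norm_dtype (the dtype values below are
# assumptions chosen for illustration; they are not tests shipped with this
# module):
#
#   >>> check_norm_dtype(torch.float64, torch.float32, "linalg.vector_norm")    # ok: widening cast
#   >>> check_norm_dtype(torch.int64, torch.float32, "linalg.vector_norm")      # raises: not fp/complex
#   >>> check_norm_dtype(torch.complex64, torch.float32, "linalg.vector_norm")  # raises: real/complex mismatch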


# Utilities should come BEFORE this import
from torch._decomp import register_decomposition


@register_decomposition(torch._ops.ops.aten.linalg_vector_norm)
@out_wrapper(exact_dtype=True)
def vector_norm(
    x: TensorLikeType,
    ord: float = 2.0,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> Tensor:
    # Checks
    check_fp_or_complex(x.dtype, "linalg.vector_norm")
    if isinstance(dim, Dim):
        dim = [dim]  # type: ignore[assignment]
    if x.numel() == 0 and (ord < 0.0 or ord == float("inf")):
        check(
            dim is not None and len(dim) != 0,
            lambda: f"linalg.vector_norm cannot compute the {ord} norm on an empty tensor "
            "because the operation does not have an identity",
        )
        shape = x.shape
        assert dim is not None  # mypy does not seem to be able to see through check?
        for d in dim:
            check(
                shape[d] != 0,
                lambda: f"linalg.vector_norm cannot compute the {ord} norm on the "
                f"dimension {d} because this dimension is empty and the "
                "operation does not have an identity",
            )
    check_norm_dtype(dtype, x.dtype, "linalg.vector_norm")

    computation_dtype, result_dtype = utils.reduction_dtypes(
        x, utils.REDUCTION_OUTPUT_TYPE_KIND.COMPLEX_TO_FLOAT, dtype
    )
    to_result_dtype = partial(_maybe_convert_to_dtype, dtype=result_dtype)

    # Implementation
    if ord == 0.0:
        return torch.sum(torch.ne(x, 0.0), dim=dim, keepdim=keepdim, dtype=result_dtype)
    elif ord == float("inf"):
        return to_result_dtype(torch.amax(torch.abs(x), dim=dim, keepdim=keepdim))  # type: ignore[return-value,arg-type]
    elif ord == float("-inf"):
        return to_result_dtype(torch.amin(torch.abs(x), dim=dim, keepdim=keepdim))  # type: ignore[return-value,arg-type]
    else:
        # From here on the computation dtype is important as the reduction is non-trivial
        x = _maybe_convert_to_dtype(x, computation_dtype)  # type: ignore[assignment]
        reduce_sum = partial(torch.sum, dim=dim, keepdim=keepdim)
        # For even orders on real inputs, x**ord is already non-negative, so abs can be skipped
        if not (ord % 2.0 == 0.0 and utils.is_float_dtype(x.dtype)):
            x = torch.abs(x)
        return to_result_dtype(torch.pow(reduce_sum(torch.pow(x, ord)), 1.0 / ord))  # type: ignore[return-value]
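
# A minimal usage sketch of vector_norm (input values assumed for illustration
# only; this is the reference decomposition of torch.linalg.vector_norm):
#
#   >>> x = torch.tensor([3.0, -4.0])
#   >>> vector_norm(x)                    # default ord=2.0: sqrt(9 + 16) -> tensor(5.)
#   >>> vector_norm(x, ord=float("inf"))  # max |x_i|                     -> tensor(4.)
#   >>> vector_norm(x, ord=0.0)           # count of non-zero entries     -> tensor(2.)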


def backshift_permutation(dim0, dim1, ndim):
    # Auxiliary function for matrix_norm
    # Computes the permutation that moves the two given dimensions to the back
    ret = [i for i in range(ndim) if i != dim0 and i != dim1]
    ret.extend((dim0, dim1))
    return ret


def inverse_permutation(perm):
    # Given a permutation, returns its inverse. It's equivalent to argsort on an array
    return [i for i, j in sorted(enumerate(perm), key=lambda i_j: i_j[1])]
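
# A worked example of the two permutation helpers (values assumed for
# illustration only):
#
#   >>> backshift_permutation(0, 2, 4)      # move dims 0 and 2 to the back
#   [1, 3, 0, 2]
#   >>> inverse_permutation([1, 3, 0, 2])   # undoes the permutation above
#   [2, 0, 3, 1]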


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def matrix_norm(
    A: TensorLikeType,
    ord: Union[float, str] = "fro",
    dim: DimsType = (-2, -1),
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    # shape
    check_is_matrix(A, "linalg.matrix_norm")
    # dim
    dim = utils.canonicalize_dims(A.ndim, dim)
    if isinstance(dim, Dim):
        dim = (dim,)  # type: ignore[assignment]
    check(len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}")
    check(
        dim[0] != dim[1],
        lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
    )
    # dtype arg
    check_norm_dtype(dtype, A.dtype, "linalg.matrix_norm")

    if isinstance(ord, str):
        # ord
        check(
            ord in ("fro", "nuc"),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != "nuc"
        )
        if ord == "fro":
            return vector_norm(A, 2, dim, keepdim, dtype=dtype)
        else:  # ord == "nuc"
            if dtype is not None:
                A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment]
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
    else:
        # ord
        abs_ord = abs(ord)
        check(
            abs_ord in (2, 1, float("inf")),
            lambda: f"linalg.matrix_norm: Order {ord} not supported.",
        )
        # dtype
        check_fp_or_complex(
            A.dtype, "linalg.matrix_norm", allow_low_precision_dtypes=ord != 2
        )

        max_min = partial(torch.amax if ord > 0.0 else torch.amin, keepdim=keepdim)

        if abs_ord == 2.0:
            if dtype is not None:
                A = _maybe_convert_to_dtype(A, dtype)  # type: ignore[assignment]
            perm = backshift_permutation(dim[0], dim[1], A.ndim)
            result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
            if keepdim:
                inv_perm = inverse_permutation(perm)
                result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
            return result
        else:  # 1, -1, inf, -inf
            dim0, dim1 = dim
            if abs_ord == float("inf"):
                dim0, dim1 = dim1, dim0
            if not keepdim and (dim0 < dim1):
                dim1 -= 1
            return max_min(
                vector_norm(A, 1.0, dim=dim0, keepdim=keepdim, dtype=dtype), dim1
            )
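
# A minimal usage sketch of matrix_norm (values assumed for illustration only):
#
#   >>> A = torch.eye(3)
#   >>> matrix_norm(A)         # Frobenius norm -> tensor(1.7321), i.e. sqrt(3)
#   >>> matrix_norm(A, ord=1)  # max absolute column sum -> tensor(1.)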


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def norm(
    A: TensorLikeType,
    ord: Optional[Union[float, str]] = None,
    dim: Optional[DimsType] = None,
    keepdim: bool = False,
    *,
    dtype: Optional[torch.dtype] = None,
) -> TensorLikeType:
    if dim is not None:
        if isinstance(dim, Dim):
            dim = (dim,)  # type: ignore[assignment]
        check(
            len(dim) in (1, 2),
            lambda: f"linalg.norm: If dim is specified, it must be of length 1 or 2. Got {dim}",
        )
    elif ord is not None:
        check(
            A.ndim in (1, 2),
            lambda: f"linalg.norm: If dim is not specified but ord is, the input must be 1D or 2D. Got {A.ndim}D",
        )

    if ord is not None and (
        (dim is not None and len(dim) == 2) or (dim is None and A.ndim == 2)
    ):
        if dim is None:
            dim = (0, 1)
        return matrix_norm(A, ord, dim, keepdim, dtype=dtype)
    else:
        if ord is None:
            ord = 2.0
        return vector_norm(A, ord, dim, keepdim, dtype=dtype)
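
# A sketch of norm's dispatch behaviour (values assumed for illustration only):
#
#   >>> norm(torch.tensor([1.0, 2.0, 2.0]))  # 1D -> vector_norm, ord=2.0 -> tensor(3.)
#   >>> norm(torch.ones(2, 2), ord="fro")    # 2D with ord -> matrix_norm -> tensor(2.)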


# CompositeImplicitAutograd
@out_wrapper("U", "S", "Vh", exact_dtype=True)
def svd(A: TensorLikeType, full_matrices: bool = True) -> Tuple[Tensor, Tensor, Tensor]:
    return prims.svd(A, full_matrices=full_matrices)


# CompositeImplicitAutograd
@out_wrapper(exact_dtype=True)
def svdvals(A: TensorLikeType) -> Tensor:
    return svd(A, full_matrices=False)[1]
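
# A minimal sketch tying svd and svdvals together (values assumed for
# illustration only):
#
#   >>> A = torch.diag(torch.tensor([3.0, 2.0]))
#   >>> svdvals(A)  # singular values in descending order -> tensor([3., 2.])
#   >>> U, S, Vh = svd(A, full_matrices=False)
#   >>> torch.allclose(U @ torch.diag(S) @ Vh, A)  # reconstruction holds
#   True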