einsumfunc.pyi 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. from collections.abc import Sequence
  2. from typing import TypeVar, Any, overload, Union, Literal
  3. from numpy import (
  4. ndarray,
  5. dtype,
  6. bool_,
  7. unsignedinteger,
  8. signedinteger,
  9. floating,
  10. complexfloating,
  11. number,
  12. _OrderKACF,
  13. )
  14. from numpy._typing import (
  15. _ArrayLikeBool_co,
  16. _ArrayLikeUInt_co,
  17. _ArrayLikeInt_co,
  18. _ArrayLikeFloat_co,
  19. _ArrayLikeComplex_co,
  20. _DTypeLikeBool,
  21. _DTypeLikeUInt,
  22. _DTypeLikeInt,
  23. _DTypeLikeFloat,
  24. _DTypeLikeComplex,
  25. _DTypeLikeComplex_co,
  26. )
  27. _ArrayType = TypeVar(
  28. "_ArrayType",
  29. bound=ndarray[Any, dtype[Union[bool_, number[Any]]]],
  30. )
  31. _OptimizeKind = None | bool | Literal["greedy", "optimal"] | Sequence[Any]
  32. _CastingSafe = Literal["no", "equiv", "safe", "same_kind"]
  33. _CastingUnsafe = Literal["unsafe"]
  34. __all__: list[str]
  35. # TODO: Properly handle the `casting`-based combinatorics
  36. # TODO: We need to evaluate the content of `subscripts` in order
  37. # to identify whether an array or a scalar is returned. At a cursory
  38. # glance this seems like something that can quite easily be done with
  39. # a mypy plugin.
  40. # Something like `is_scalar = bool(subscripts.partition("->")[-1])`
  41. @overload
  42. def einsum(
  43. subscripts: str | _ArrayLikeInt_co,
  44. /,
  45. *operands: _ArrayLikeBool_co,
  46. out: None = ...,
  47. dtype: None | _DTypeLikeBool = ...,
  48. order: _OrderKACF = ...,
  49. casting: _CastingSafe = ...,
  50. optimize: _OptimizeKind = ...,
  51. ) -> Any: ...
  52. @overload
  53. def einsum(
  54. subscripts: str | _ArrayLikeInt_co,
  55. /,
  56. *operands: _ArrayLikeUInt_co,
  57. out: None = ...,
  58. dtype: None | _DTypeLikeUInt = ...,
  59. order: _OrderKACF = ...,
  60. casting: _CastingSafe = ...,
  61. optimize: _OptimizeKind = ...,
  62. ) -> Any: ...
  63. @overload
  64. def einsum(
  65. subscripts: str | _ArrayLikeInt_co,
  66. /,
  67. *operands: _ArrayLikeInt_co,
  68. out: None = ...,
  69. dtype: None | _DTypeLikeInt = ...,
  70. order: _OrderKACF = ...,
  71. casting: _CastingSafe = ...,
  72. optimize: _OptimizeKind = ...,
  73. ) -> Any: ...
  74. @overload
  75. def einsum(
  76. subscripts: str | _ArrayLikeInt_co,
  77. /,
  78. *operands: _ArrayLikeFloat_co,
  79. out: None = ...,
  80. dtype: None | _DTypeLikeFloat = ...,
  81. order: _OrderKACF = ...,
  82. casting: _CastingSafe = ...,
  83. optimize: _OptimizeKind = ...,
  84. ) -> Any: ...
  85. @overload
  86. def einsum(
  87. subscripts: str | _ArrayLikeInt_co,
  88. /,
  89. *operands: _ArrayLikeComplex_co,
  90. out: None = ...,
  91. dtype: None | _DTypeLikeComplex = ...,
  92. order: _OrderKACF = ...,
  93. casting: _CastingSafe = ...,
  94. optimize: _OptimizeKind = ...,
  95. ) -> Any: ...
  96. @overload
  97. def einsum(
  98. subscripts: str | _ArrayLikeInt_co,
  99. /,
  100. *operands: Any,
  101. casting: _CastingUnsafe,
  102. dtype: None | _DTypeLikeComplex_co = ...,
  103. out: None = ...,
  104. order: _OrderKACF = ...,
  105. optimize: _OptimizeKind = ...,
  106. ) -> Any: ...
  107. @overload
  108. def einsum(
  109. subscripts: str | _ArrayLikeInt_co,
  110. /,
  111. *operands: _ArrayLikeComplex_co,
  112. out: _ArrayType,
  113. dtype: None | _DTypeLikeComplex_co = ...,
  114. order: _OrderKACF = ...,
  115. casting: _CastingSafe = ...,
  116. optimize: _OptimizeKind = ...,
  117. ) -> _ArrayType: ...
  118. @overload
  119. def einsum(
  120. subscripts: str | _ArrayLikeInt_co,
  121. /,
  122. *operands: Any,
  123. out: _ArrayType,
  124. casting: _CastingUnsafe,
  125. dtype: None | _DTypeLikeComplex_co = ...,
  126. order: _OrderKACF = ...,
  127. optimize: _OptimizeKind = ...,
  128. ) -> _ArrayType: ...
  129. # NOTE: `einsum_call` is a hidden kwarg unavailable for public use.
  130. # It is therefore excluded from the signatures below.
  131. # NOTE: In practice the list consists of a `str` (first element)
  132. # and a variable number of integer tuples.
  133. def einsum_path(
  134. subscripts: str | _ArrayLikeInt_co,
  135. /,
  136. *operands: _ArrayLikeComplex_co,
  137. optimize: _OptimizeKind = ...,
  138. ) -> tuple[list[Any], str]: ...