"""
A NumPy sub-namespace that conforms to the Python array API standard.

This submodule accompanies NEP 47, which proposes its inclusion in NumPy. It
is still considered experimental, and will issue a warning when imported.

This is a proof-of-concept namespace that wraps the corresponding NumPy
functions to give a conforming implementation of the Python array API standard
(https://data-apis.github.io/array-api/latest/). The standard is currently in
an RFC phase and comments on it are both welcome and encouraged. Comments
should be made either at https://github.com/data-apis/array-api or at
https://github.com/data-apis/consortium-feedback/discussions.

NumPy already follows the proposed spec for the most part, so this module
serves mostly as a thin wrapper around it. However, NumPy also implements a
lot of behavior that is not included in the spec, so this serves as a
restricted subset of the API. Only those functions that are part of the spec
are included in this namespace, and all functions are given with the exact
signature given in the spec, including the use of position-only arguments, and
omitting any extra keyword arguments implemented by NumPy but not part of the
spec. The behavior of some functions is also modified from the NumPy behavior
to conform to the standard. Note that the underlying array object itself is
wrapped in a wrapper Array() class, but is otherwise unchanged. This submodule
is implemented in pure Python with no C extensions.

The array API spec is designed as a "minimal API subset" and explicitly allows
libraries to include behaviors not specified by it. But users of this module
that intend to write portable code should be aware that only those behaviors
that are listed in the spec are guaranteed to be implemented across libraries.
Consequently, the NumPy implementation was chosen to be both conforming and
minimal, so that users can use this implementation of the array API namespace
and be sure that behaviors that it defines will be available in conforming
namespaces from other libraries.

A few notes about the current state of this submodule:

- There is a test suite that tests modules against the array API standard at
  https://github.com/data-apis/array-api-tests. The test suite is still a work
  in progress, but the existing tests pass on this module, with a few
  exceptions:

  - DLPack support (see https://github.com/data-apis/array-api/pull/106) is
    incomplete. from_dlpack() is exported, but full DLPack support requires
    an implementation in NumPy proper first.

  The test suite is not yet complete, and even the tests that exist are not
  guaranteed to give comprehensive coverage of the spec. Therefore, when
  reviewing and using this submodule, you should refer to the standard
  documents themselves. There are some tests in numpy.array_api.tests, but
  they primarily focus on things that are not tested by the official array API
  test suite.

- There is a custom array object, numpy.array_api.Array, which is returned by
  all functions in this module. All functions in the array API namespace
  implicitly assume that they will only receive this object as input. The only
  way to create instances of this object is to use one of the array creation
  functions. It does not have a public constructor on the object itself. The
  object is a small wrapper class around numpy.ndarray. The main purpose of it
  is to restrict the namespace of the array object to only those dtypes and
  only those methods that are required by the spec, as well as to limit/change
  certain behavior that differs in the spec. In particular:

  - The array API namespace does not have scalar objects, only 0-D arrays.
    Operations on Array that would create a scalar in NumPy create a 0-D
    array.

  - Indexing: Only a subset of the indices supported by NumPy is required by
    the spec. The Array object restricts indexing to only allow those types of
    indices that are required by the spec. See the docstring of the
    numpy.array_api.Array._validate_indices helper function for more
    information.

  - Type promotion: Some type promotion rules are different in the spec. In
    particular, the spec does not have any value-based casting. The spec also
    does not require cross-kind casting, like integer -> floating-point. Only
    those promotions that are explicitly required by the array API
    specification are allowed in this module. See NEP 47 for more info.

  - Functions do not automatically call asarray() on their input, and will not
    work if the input type is not Array. The exception is array creation
    functions, and Python operators on the Array object, which accept Python
    scalars of the same type as the array dtype.

- All functions include type annotations, corresponding to those given in the
  spec (see _typing.py for definitions of some custom types). These do not
  currently fully pass mypy due to some limitations in mypy.

- Dtype objects are just the NumPy dtype objects, e.g., float64 =
  np.dtype('float64'). The spec does not require any behavior on these dtype
  objects other than that they be accessible by name and be comparable by
  equality, but it was considered too much extra complexity to create custom
  objects to represent dtypes.

- All places where the implementations in this submodule are known to deviate
  from their corresponding functions in NumPy are marked with "# Note:"
  comments.

Still TODO in this module:

- DLPack support for numpy.ndarray is still in progress. See
  https://github.com/numpy/numpy/pull/19083.

- The copy=False keyword argument to asarray() is not yet implemented. This
  requires support in numpy.asarray() first.

- Some functions are not yet fully tested in the array API test suite, and may
  require updates that are not yet known until the tests are written.

- The spec is still in an RFC phase and may still have minor updates, which
  will need to be reflected here.

- Complex number support in the array API spec is planned but not yet
  finalized, as are the fft extension and certain linear algebra functions
  such as eig that require complex dtypes.
"""

import warnings

# Importing this namespace always emits a warning, because the submodule is
# still experimental (see NEP 47).
warnings.warn(
    "The numpy.array_api submodule is still experimental. See NEP 47.", stacklevel=2
)

# Version string of the array API standard this namespace targets.
__array_api_version__ = "2021.12"

# Public API. Each group of names below is appended right after the import
# that brings it into the namespace, so the two lists stay side by side.
__all__ = ["__array_api_version__"]

  100. from ._constants import e, inf, nan, pi
  101. __all__ += ["e", "inf", "nan", "pi"]
  102. from ._creation_functions import (
  103. asarray,
  104. arange,
  105. empty,
  106. empty_like,
  107. eye,
  108. from_dlpack,
  109. full,
  110. full_like,
  111. linspace,
  112. meshgrid,
  113. ones,
  114. ones_like,
  115. tril,
  116. triu,
  117. zeros,
  118. zeros_like,
  119. )
  120. __all__ += [
  121. "asarray",
  122. "arange",
  123. "empty",
  124. "empty_like",
  125. "eye",
  126. "from_dlpack",
  127. "full",
  128. "full_like",
  129. "linspace",
  130. "meshgrid",
  131. "ones",
  132. "ones_like",
  133. "tril",
  134. "triu",
  135. "zeros",
  136. "zeros_like",
  137. ]
  138. from ._data_type_functions import (
  139. astype,
  140. broadcast_arrays,
  141. broadcast_to,
  142. can_cast,
  143. finfo,
  144. iinfo,
  145. result_type,
  146. )
  147. __all__ += [
  148. "astype",
  149. "broadcast_arrays",
  150. "broadcast_to",
  151. "can_cast",
  152. "finfo",
  153. "iinfo",
  154. "result_type",
  155. ]
  156. from ._dtypes import (
  157. int8,
  158. int16,
  159. int32,
  160. int64,
  161. uint8,
  162. uint16,
  163. uint32,
  164. uint64,
  165. float32,
  166. float64,
  167. bool,
  168. )
  169. __all__ += [
  170. "int8",
  171. "int16",
  172. "int32",
  173. "int64",
  174. "uint8",
  175. "uint16",
  176. "uint32",
  177. "uint64",
  178. "float32",
  179. "float64",
  180. "bool",
  181. ]
  182. from ._elementwise_functions import (
  183. abs,
  184. acos,
  185. acosh,
  186. add,
  187. asin,
  188. asinh,
  189. atan,
  190. atan2,
  191. atanh,
  192. bitwise_and,
  193. bitwise_left_shift,
  194. bitwise_invert,
  195. bitwise_or,
  196. bitwise_right_shift,
  197. bitwise_xor,
  198. ceil,
  199. cos,
  200. cosh,
  201. divide,
  202. equal,
  203. exp,
  204. expm1,
  205. floor,
  206. floor_divide,
  207. greater,
  208. greater_equal,
  209. isfinite,
  210. isinf,
  211. isnan,
  212. less,
  213. less_equal,
  214. log,
  215. log1p,
  216. log2,
  217. log10,
  218. logaddexp,
  219. logical_and,
  220. logical_not,
  221. logical_or,
  222. logical_xor,
  223. multiply,
  224. negative,
  225. not_equal,
  226. positive,
  227. pow,
  228. remainder,
  229. round,
  230. sign,
  231. sin,
  232. sinh,
  233. square,
  234. sqrt,
  235. subtract,
  236. tan,
  237. tanh,
  238. trunc,
  239. )
  240. __all__ += [
  241. "abs",
  242. "acos",
  243. "acosh",
  244. "add",
  245. "asin",
  246. "asinh",
  247. "atan",
  248. "atan2",
  249. "atanh",
  250. "bitwise_and",
  251. "bitwise_left_shift",
  252. "bitwise_invert",
  253. "bitwise_or",
  254. "bitwise_right_shift",
  255. "bitwise_xor",
  256. "ceil",
  257. "cos",
  258. "cosh",
  259. "divide",
  260. "equal",
  261. "exp",
  262. "expm1",
  263. "floor",
  264. "floor_divide",
  265. "greater",
  266. "greater_equal",
  267. "isfinite",
  268. "isinf",
  269. "isnan",
  270. "less",
  271. "less_equal",
  272. "log",
  273. "log1p",
  274. "log2",
  275. "log10",
  276. "logaddexp",
  277. "logical_and",
  278. "logical_not",
  279. "logical_or",
  280. "logical_xor",
  281. "multiply",
  282. "negative",
  283. "not_equal",
  284. "positive",
  285. "pow",
  286. "remainder",
  287. "round",
  288. "sign",
  289. "sin",
  290. "sinh",
  291. "square",
  292. "sqrt",
  293. "subtract",
  294. "tan",
  295. "tanh",
  296. "trunc",
  297. ]
  298. # linalg is an extension in the array API spec, which is a sub-namespace. Only
  299. # a subset of functions in it are imported into the top-level namespace.
  300. from . import linalg
  301. __all__ += ["linalg"]
  302. from .linalg import matmul, tensordot, matrix_transpose, vecdot
  303. __all__ += ["matmul", "tensordot", "matrix_transpose", "vecdot"]
  304. from ._manipulation_functions import (
  305. concat,
  306. expand_dims,
  307. flip,
  308. permute_dims,
  309. reshape,
  310. roll,
  311. squeeze,
  312. stack,
  313. )
  314. __all__ += ["concat", "expand_dims", "flip", "permute_dims", "reshape", "roll", "squeeze", "stack"]
  315. from ._searching_functions import argmax, argmin, nonzero, where
  316. __all__ += ["argmax", "argmin", "nonzero", "where"]
  317. from ._set_functions import unique_all, unique_counts, unique_inverse, unique_values
  318. __all__ += ["unique_all", "unique_counts", "unique_inverse", "unique_values"]
  319. from ._sorting_functions import argsort, sort
  320. __all__ += ["argsort", "sort"]
  321. from ._statistical_functions import max, mean, min, prod, std, sum, var
  322. __all__ += ["max", "mean", "min", "prod", "std", "sum", "var"]
  323. from ._utility_functions import all, any
  324. __all__ += ["all", "any"]