_elementwise_functions.py 24 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729
  1. from __future__ import annotations
  2. from ._dtypes import (
  3. _boolean_dtypes,
  4. _floating_dtypes,
  5. _integer_dtypes,
  6. _integer_or_boolean_dtypes,
  7. _numeric_dtypes,
  8. _result_type,
  9. )
  10. from ._array_object import Array
  11. import numpy as np
  def abs(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.abs <numpy.abs>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.abs(x._array))
      raise TypeError("Only numeric dtypes are allowed in abs")
  # Note: the spec name is ``acos``; NumPy spells it ``arccos``.
  def acos(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arccos <numpy.arccos>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.arccos(x._array))
      raise TypeError("Only floating-point dtypes are allowed in acos")
  # Note: the spec name is ``acosh``; NumPy spells it ``arccosh``.
  def acosh(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arccosh <numpy.arccosh>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.arccosh(x._array))
      raise TypeError("Only floating-point dtypes are allowed in acosh")
  def add(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.add <numpy.add>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in add")
      # The promoted dtype is discarded; this call exists only so that
      # disallowed dtype combinations raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.add(lhs._array, rhs._array)
      return Array._new(out)
  # Note: the spec name is ``asin``; NumPy spells it ``arcsin``.
  def asin(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arcsin <numpy.arcsin>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.arcsin(x._array))
      raise TypeError("Only floating-point dtypes are allowed in asin")
  # Note: the spec name is ``asinh``; NumPy spells it ``arcsinh``.
  def asinh(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arcsinh <numpy.arcsinh>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.arcsinh(x._array))
      raise TypeError("Only floating-point dtypes are allowed in asinh")
  # Note: the spec name is ``atan``; NumPy spells it ``arctan``.
  def atan(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arctan <numpy.arctan>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.arctan(x._array))
      raise TypeError("Only floating-point dtypes are allowed in atan")
  # Note: the spec name is ``atan2``; NumPy spells it ``arctan2``.
  def atan2(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arctan2 <numpy.arctan2>`.

      Both operands must have floating-point dtypes (``TypeError``
      otherwise), and the dtype pair is checked with ``_result_type``
      before NumPy is called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _floating_dtypes and x2.dtype in _floating_dtypes):
          raise TypeError("Only floating-point dtypes are allowed in atan2")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.arctan2(lhs._array, rhs._array)
      return Array._new(out)
  # Note: the spec name is ``atanh``; NumPy spells it ``arctanh``.
  def atanh(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.arctanh <numpy.arctanh>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.arctanh(x._array))
      raise TypeError("Only floating-point dtypes are allowed in atanh")
  def bitwise_and(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.bitwise_and <numpy.bitwise_and>`.

      Both operands must have integer or boolean dtypes (``TypeError``
      otherwise), and the dtype pair is checked with ``_result_type``
      before NumPy is called. See the NumPy docstring for more information.
      """
      allowed = _integer_or_boolean_dtypes
      if not (x1.dtype in allowed and x2.dtype in allowed):
          raise TypeError("Only integer or boolean dtypes are allowed in bitwise_and")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.bitwise_and(lhs._array, rhs._array)
      return Array._new(out)
  # Note: the spec name is ``bitwise_left_shift``; NumPy calls it ``left_shift``.
  def bitwise_left_shift(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.left_shift <numpy.left_shift>`.

      Both operands must have integer dtypes (``TypeError`` otherwise), the
      dtype pair is checked with ``_result_type``, and ``ValueError`` is
      raised if any shift amount in ``x2`` is negative. See the NumPy
      docstring for more information.
      """
      if not (x1.dtype in _integer_dtypes and x2.dtype in _integer_dtypes):
          raise TypeError("Only integer dtypes are allowed in bitwise_left_shift")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      # The operation is only defined for non-negative shift amounts.
      has_negative_shift = np.any(rhs._array < 0)
      if has_negative_shift:
          raise ValueError("bitwise_left_shift(x1, x2) is only defined for x2 >= 0")
      return Array._new(np.left_shift(lhs._array, rhs._array))
  # Note: the spec name is ``bitwise_invert``; NumPy calls it ``invert``.
  def bitwise_invert(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.invert <numpy.invert>`.

      Accepts integer or boolean dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _integer_or_boolean_dtypes:
          return Array._new(np.invert(x._array))
      raise TypeError("Only integer or boolean dtypes are allowed in bitwise_invert")
  def bitwise_or(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.bitwise_or <numpy.bitwise_or>`.

      Both operands must have integer or boolean dtypes (``TypeError``
      otherwise), and the dtype pair is checked with ``_result_type``
      before NumPy is called. See the NumPy docstring for more information.
      """
      allowed = _integer_or_boolean_dtypes
      if not (x1.dtype in allowed and x2.dtype in allowed):
          raise TypeError("Only integer or boolean dtypes are allowed in bitwise_or")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.bitwise_or(lhs._array, rhs._array)
      return Array._new(out)
  # Note: the spec name is ``bitwise_right_shift``; NumPy calls it ``right_shift``.
  def bitwise_right_shift(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.right_shift <numpy.right_shift>`.

      Both operands must have integer dtypes (``TypeError`` otherwise), the
      dtype pair is checked with ``_result_type``, and ``ValueError`` is
      raised if any shift amount in ``x2`` is negative. See the NumPy
      docstring for more information.
      """
      if not (x1.dtype in _integer_dtypes and x2.dtype in _integer_dtypes):
          raise TypeError("Only integer dtypes are allowed in bitwise_right_shift")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      # The operation is only defined for non-negative shift amounts.
      has_negative_shift = np.any(rhs._array < 0)
      if has_negative_shift:
          raise ValueError("bitwise_right_shift(x1, x2) is only defined for x2 >= 0")
      return Array._new(np.right_shift(lhs._array, rhs._array))
  def bitwise_xor(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.bitwise_xor <numpy.bitwise_xor>`.

      Both operands must have integer or boolean dtypes (``TypeError``
      otherwise), and the dtype pair is checked with ``_result_type``
      before NumPy is called. See the NumPy docstring for more information.
      """
      allowed = _integer_or_boolean_dtypes
      if not (x1.dtype in allowed and x2.dtype in allowed):
          raise TypeError("Only integer or boolean dtypes are allowed in bitwise_xor")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.bitwise_xor(lhs._array, rhs._array)
      return Array._new(out)
  def ceil(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.ceil <numpy.ceil>`.

      Accepts numeric dtypes only (``TypeError`` otherwise). Integer inputs
      are returned unchanged, as the very same object. See the NumPy
      docstring for more information.
      """
      if x.dtype not in _numeric_dtypes:
          raise TypeError("Only numeric dtypes are allowed in ceil")
      # Integers need no rounding, and the result dtype matches the input,
      # so the input itself is returned for integer dtypes.
      return x if x.dtype in _integer_dtypes else Array._new(np.ceil(x._array))
  def cos(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.cos <numpy.cos>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.cos(x._array))
      raise TypeError("Only floating-point dtypes are allowed in cos")
  def cosh(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.cosh <numpy.cosh>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.cosh(x._array))
      raise TypeError("Only floating-point dtypes are allowed in cosh")
  def divide(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.divide <numpy.divide>`.

      Both operands must have floating-point dtypes (``TypeError``
      otherwise), and the dtype pair is checked with ``_result_type``
      before NumPy is called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _floating_dtypes and x2.dtype in _floating_dtypes):
          raise TypeError("Only floating-point dtypes are allowed in divide")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.divide(lhs._array, rhs._array)
      return Array._new(out)
  def equal(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.equal <numpy.equal>`.

      No dtype category is enforced here; the only check is that
      ``_result_type`` accepts the dtype pair. See the NumPy docstring for
      more information.
      """
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.equal(lhs._array, rhs._array)
      return Array._new(out)
  def exp(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.exp <numpy.exp>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.exp(x._array))
      raise TypeError("Only floating-point dtypes are allowed in exp")
  def expm1(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.expm1 <numpy.expm1>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.expm1(x._array))
      raise TypeError("Only floating-point dtypes are allowed in expm1")
  def floor(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.floor <numpy.floor>`.

      Accepts numeric dtypes only (``TypeError`` otherwise). Integer inputs
      are returned unchanged, as the very same object. See the NumPy
      docstring for more information.
      """
      if x.dtype not in _numeric_dtypes:
          raise TypeError("Only numeric dtypes are allowed in floor")
      # Integers need no rounding, and the result dtype matches the input,
      # so the input itself is returned for integer dtypes.
      return x if x.dtype in _integer_dtypes else Array._new(np.floor(x._array))
  def floor_divide(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.floor_divide <numpy.floor_divide>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in floor_divide")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.floor_divide(lhs._array, rhs._array)
      return Array._new(out)
  def greater(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.greater <numpy.greater>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in greater")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.greater(lhs._array, rhs._array)
      return Array._new(out)
  def greater_equal(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.greater_equal <numpy.greater_equal>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in greater_equal")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.greater_equal(lhs._array, rhs._array)
      return Array._new(out)
  def isfinite(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.isfinite <numpy.isfinite>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.isfinite(x._array))
      raise TypeError("Only numeric dtypes are allowed in isfinite")
  def isinf(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.isinf <numpy.isinf>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.isinf(x._array))
      raise TypeError("Only numeric dtypes are allowed in isinf")
  def isnan(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.isnan <numpy.isnan>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.isnan(x._array))
      raise TypeError("Only numeric dtypes are allowed in isnan")
  def less(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.less <numpy.less>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in less")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.less(lhs._array, rhs._array)
      return Array._new(out)
  def less_equal(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.less_equal <numpy.less_equal>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in less_equal")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.less_equal(lhs._array, rhs._array)
      return Array._new(out)
  def log(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.log <numpy.log>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.log(x._array))
      raise TypeError("Only floating-point dtypes are allowed in log")
  def log1p(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.log1p <numpy.log1p>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.log1p(x._array))
      raise TypeError("Only floating-point dtypes are allowed in log1p")
  def log2(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.log2 <numpy.log2>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.log2(x._array))
      raise TypeError("Only floating-point dtypes are allowed in log2")
  def log10(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.log10 <numpy.log10>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.log10(x._array))
      raise TypeError("Only floating-point dtypes are allowed in log10")
  # NOTE(review): unlike every other binary function in this module, this
  # signature has no ``/`` marker, so x1/x2 can also be passed by keyword.
  # It is left as-is to avoid breaking any caller that relies on that.
  def logaddexp(x1: Array, x2: Array) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.logaddexp <numpy.logaddexp>`.

      Both operands must have floating-point dtypes (``TypeError``
      otherwise), and the dtype pair is checked with ``_result_type``
      before NumPy is called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _floating_dtypes and x2.dtype in _floating_dtypes):
          raise TypeError("Only floating-point dtypes are allowed in logaddexp")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.logaddexp(lhs._array, rhs._array)
      return Array._new(out)
  def logical_and(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.logical_and <numpy.logical_and>`.

      Both operands must have boolean dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _boolean_dtypes and x2.dtype in _boolean_dtypes):
          raise TypeError("Only boolean dtypes are allowed in logical_and")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.logical_and(lhs._array, rhs._array)
      return Array._new(out)
  def logical_not(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.logical_not <numpy.logical_not>`.

      Accepts boolean dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _boolean_dtypes:
          return Array._new(np.logical_not(x._array))
      raise TypeError("Only boolean dtypes are allowed in logical_not")
  def logical_or(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.logical_or <numpy.logical_or>`.

      Both operands must have boolean dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _boolean_dtypes and x2.dtype in _boolean_dtypes):
          raise TypeError("Only boolean dtypes are allowed in logical_or")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.logical_or(lhs._array, rhs._array)
      return Array._new(out)
  def logical_xor(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.logical_xor <numpy.logical_xor>`.

      Both operands must have boolean dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _boolean_dtypes and x2.dtype in _boolean_dtypes):
          raise TypeError("Only boolean dtypes are allowed in logical_xor")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.logical_xor(lhs._array, rhs._array)
      return Array._new(out)
  def multiply(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.multiply <numpy.multiply>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in multiply")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.multiply(lhs._array, rhs._array)
      return Array._new(out)
  def negative(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.negative <numpy.negative>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.negative(x._array))
      raise TypeError("Only numeric dtypes are allowed in negative")
  def not_equal(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.not_equal <numpy.not_equal>`.

      No dtype category is enforced here; the only check is that
      ``_result_type`` accepts the dtype pair. See the NumPy docstring for
      more information.
      """
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.not_equal(lhs._array, rhs._array)
      return Array._new(out)
  def positive(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.positive <numpy.positive>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.positive(x._array))
      raise TypeError("Only numeric dtypes are allowed in positive")
  # Note: the spec name is ``pow``; NumPy calls it ``power``.
  def pow(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.power <numpy.power>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in pow")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      base, exponent = Array._normalize_two_args(x1, x2)
      out = np.power(base._array, exponent._array)
      return Array._new(out)
  def remainder(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.remainder <numpy.remainder>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in remainder")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.remainder(lhs._array, rhs._array)
      return Array._new(out)
  def round(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.round <numpy.round>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      NumPy's default number of decimals is used. See the NumPy docstring
      for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.round(x._array))
      raise TypeError("Only numeric dtypes are allowed in round")
  def sign(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.sign <numpy.sign>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.sign(x._array))
      raise TypeError("Only numeric dtypes are allowed in sign")
  def sin(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.sin <numpy.sin>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.sin(x._array))
      raise TypeError("Only floating-point dtypes are allowed in sin")
  def sinh(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.sinh <numpy.sinh>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.sinh(x._array))
      raise TypeError("Only floating-point dtypes are allowed in sinh")
  def square(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.square <numpy.square>`.

      Accepts numeric dtypes only; any other dtype raises ``TypeError``.
      See the NumPy docstring for more information.
      """
      if x.dtype in _numeric_dtypes:
          return Array._new(np.square(x._array))
      raise TypeError("Only numeric dtypes are allowed in square")
  def sqrt(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.sqrt <numpy.sqrt>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.sqrt(x._array))
      raise TypeError("Only floating-point dtypes are allowed in sqrt")
  def subtract(x1: Array, x2: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.subtract <numpy.subtract>`.

      Both operands must have numeric dtypes (``TypeError`` otherwise), and
      the dtype pair is checked with ``_result_type`` before NumPy is
      called. See the NumPy docstring for more information.
      """
      if not (x1.dtype in _numeric_dtypes and x2.dtype in _numeric_dtypes):
          raise TypeError("Only numeric dtypes are allowed in subtract")
      # Result discarded: called only so disallowed dtype pairs raise.
      _result_type(x1.dtype, x2.dtype)
      lhs, rhs = Array._normalize_two_args(x1, x2)
      out = np.subtract(lhs._array, rhs._array)
      return Array._new(out)
  def tan(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.tan <numpy.tan>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.tan(x._array))
      raise TypeError("Only floating-point dtypes are allowed in tan")
  def tanh(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.tanh <numpy.tanh>`.

      Accepts floating-point dtypes only; any other dtype raises
      ``TypeError``. See the NumPy docstring for more information.
      """
      if x.dtype in _floating_dtypes:
          return Array._new(np.tanh(x._array))
      raise TypeError("Only floating-point dtypes are allowed in tanh")
  def trunc(x: Array, /) -> Array:
      """
      Array API compatible wrapper for :py:func:`np.trunc <numpy.trunc>`.

      Accepts numeric dtypes only (``TypeError`` otherwise). Integer inputs
      are returned unchanged, as the very same object. See the NumPy
      docstring for more information.
      """
      if x.dtype not in _numeric_dtypes:
          raise TypeError("Only numeric dtypes are allowed in trunc")
      # Integers need no truncation, and the result dtype matches the input,
      # so the input itself is returned for integer dtypes.
      return x if x.dtype in _integer_dtypes else Array._new(np.trunc(x._array))