_ufunc_config.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. """
  2. Functions for changing global ufunc configuration
  3. This provides helpers which wrap `umath.geterrobj` and `umath.seterrobj`
  4. """
  5. import collections.abc
  6. import contextlib
  7. import contextvars
  8. from .overrides import set_module
  9. from .umath import (
  10. UFUNC_BUFSIZE_DEFAULT,
  11. ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, ERR_DEFAULT,
  12. SHIFT_DIVIDEBYZERO, SHIFT_OVERFLOW, SHIFT_UNDERFLOW, SHIFT_INVALID,
  13. )
  14. from . import umath
  15. __all__ = [
  16. "seterr", "geterr", "setbufsize", "getbufsize", "seterrcall", "geterrcall",
  17. "errstate", '_no_nep50_warning'
  18. ]
  19. _errdict = {"ignore": ERR_IGNORE,
  20. "warn": ERR_WARN,
  21. "raise": ERR_RAISE,
  22. "call": ERR_CALL,
  23. "print": ERR_PRINT,
  24. "log": ERR_LOG}
  25. _errdict_rev = {value: key for key, value in _errdict.items()}
  26. @set_module('numpy')
  27. def seterr(all=None, divide=None, over=None, under=None, invalid=None):
  28. """
  29. Set how floating-point errors are handled.
  30. Note that operations on integer scalar types (such as `int16`) are
  31. handled like floating point, and are affected by these settings.
  32. Parameters
  33. ----------
  34. all : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
  35. Set treatment for all types of floating-point errors at once:
  36. - ignore: Take no action when the exception occurs.
  37. - warn: Print a `RuntimeWarning` (via the Python `warnings` module).
  38. - raise: Raise a `FloatingPointError`.
  39. - call: Call a function specified using the `seterrcall` function.
  40. - print: Print a warning directly to ``stdout``.
  41. - log: Record error in a Log object specified by `seterrcall`.
  42. The default is not to change the current behavior.
  43. divide : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
  44. Treatment for division by zero.
  45. over : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
  46. Treatment for floating-point overflow.
  47. under : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
  48. Treatment for floating-point underflow.
  49. invalid : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
  50. Treatment for invalid floating-point operation.
  51. Returns
  52. -------
  53. old_settings : dict
  54. Dictionary containing the old settings.
  55. See also
  56. --------
  57. seterrcall : Set a callback function for the 'call' mode.
  58. geterr, geterrcall, errstate
  59. Notes
  60. -----
  61. The floating-point exceptions are defined in the IEEE 754 standard [1]_:
  62. - Division by zero: infinite result obtained from finite numbers.
  63. - Overflow: result too large to be expressed.
  64. - Underflow: result so close to zero that some precision
  65. was lost.
  66. - Invalid operation: result is not an expressible number, typically
  67. indicates that a NaN was produced.
  68. .. [1] https://en.wikipedia.org/wiki/IEEE_754
  69. Examples
  70. --------
  71. >>> old_settings = np.seterr(all='ignore') #seterr to known value
  72. >>> np.seterr(over='raise')
  73. {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
  74. >>> np.seterr(**old_settings) # reset to default
  75. {'divide': 'ignore', 'over': 'raise', 'under': 'ignore', 'invalid': 'ignore'}
  76. >>> np.int16(32000) * np.int16(3)
  77. 30464
  78. >>> old_settings = np.seterr(all='warn', over='raise')
  79. >>> np.int16(32000) * np.int16(3)
  80. Traceback (most recent call last):
  81. File "<stdin>", line 1, in <module>
  82. FloatingPointError: overflow encountered in scalar multiply
  83. >>> old_settings = np.seterr(all='print')
  84. >>> np.geterr()
  85. {'divide': 'print', 'over': 'print', 'under': 'print', 'invalid': 'print'}
  86. >>> np.int16(32000) * np.int16(3)
  87. 30464
  88. """
  89. pyvals = umath.geterrobj()
  90. old = geterr()
  91. if divide is None:
  92. divide = all or old['divide']
  93. if over is None:
  94. over = all or old['over']
  95. if under is None:
  96. under = all or old['under']
  97. if invalid is None:
  98. invalid = all or old['invalid']
  99. maskvalue = ((_errdict[divide] << SHIFT_DIVIDEBYZERO) +
  100. (_errdict[over] << SHIFT_OVERFLOW) +
  101. (_errdict[under] << SHIFT_UNDERFLOW) +
  102. (_errdict[invalid] << SHIFT_INVALID))
  103. pyvals[1] = maskvalue
  104. umath.seterrobj(pyvals)
  105. return old
  106. @set_module('numpy')
  107. def geterr():
  108. """
  109. Get the current way of handling floating-point errors.
  110. Returns
  111. -------
  112. res : dict
  113. A dictionary with keys "divide", "over", "under", and "invalid",
  114. whose values are from the strings "ignore", "print", "log", "warn",
  115. "raise", and "call". The keys represent possible floating-point
  116. exceptions, and the values define how these exceptions are handled.
  117. See Also
  118. --------
  119. geterrcall, seterr, seterrcall
  120. Notes
  121. -----
  122. For complete documentation of the types of floating-point exceptions and
  123. treatment options, see `seterr`.
  124. Examples
  125. --------
  126. >>> np.geterr()
  127. {'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}
  128. >>> np.arange(3.) / np.arange(3.)
  129. array([nan, 1., 1.])
  130. >>> oldsettings = np.seterr(all='warn', over='raise')
  131. >>> np.geterr()
  132. {'divide': 'warn', 'over': 'raise', 'under': 'warn', 'invalid': 'warn'}
  133. >>> np.arange(3.) / np.arange(3.)
  134. array([nan, 1., 1.])
  135. """
  136. maskvalue = umath.geterrobj()[1]
  137. mask = 7
  138. res = {}
  139. val = (maskvalue >> SHIFT_DIVIDEBYZERO) & mask
  140. res['divide'] = _errdict_rev[val]
  141. val = (maskvalue >> SHIFT_OVERFLOW) & mask
  142. res['over'] = _errdict_rev[val]
  143. val = (maskvalue >> SHIFT_UNDERFLOW) & mask
  144. res['under'] = _errdict_rev[val]
  145. val = (maskvalue >> SHIFT_INVALID) & mask
  146. res['invalid'] = _errdict_rev[val]
  147. return res
  148. @set_module('numpy')
  149. def setbufsize(size):
  150. """
  151. Set the size of the buffer used in ufuncs.
  152. Parameters
  153. ----------
  154. size : int
  155. Size of buffer.
  156. """
  157. if size > 10e6:
  158. raise ValueError("Buffer size, %s, is too big." % size)
  159. if size < 5:
  160. raise ValueError("Buffer size, %s, is too small." % size)
  161. if size % 16 != 0:
  162. raise ValueError("Buffer size, %s, is not a multiple of 16." % size)
  163. pyvals = umath.geterrobj()
  164. old = getbufsize()
  165. pyvals[0] = size
  166. umath.seterrobj(pyvals)
  167. return old
  168. @set_module('numpy')
  169. def getbufsize():
  170. """
  171. Return the size of the buffer used in ufuncs.
  172. Returns
  173. -------
  174. getbufsize : int
  175. Size of ufunc buffer in bytes.
  176. """
  177. return umath.geterrobj()[0]
  178. @set_module('numpy')
  179. def seterrcall(func):
  180. """
  181. Set the floating-point error callback function or log object.
  182. There are two ways to capture floating-point error messages. The first
  183. is to set the error-handler to 'call', using `seterr`. Then, set
  184. the function to call using this function.
  185. The second is to set the error-handler to 'log', using `seterr`.
  186. Floating-point errors then trigger a call to the 'write' method of
  187. the provided object.
  188. Parameters
  189. ----------
  190. func : callable f(err, flag) or object with write method
  191. Function to call upon floating-point errors ('call'-mode) or
  192. object whose 'write' method is used to log such message ('log'-mode).
  193. The call function takes two arguments. The first is a string describing
  194. the type of error (such as "divide by zero", "overflow", "underflow",
  195. or "invalid value"), and the second is the status flag. The flag is a
  196. byte, whose four least-significant bits indicate the type of error, one
  197. of "divide", "over", "under", "invalid"::
  198. [0 0 0 0 divide over under invalid]
  199. In other words, ``flags = divide + 2*over + 4*under + 8*invalid``.
  200. If an object is provided, its write method should take one argument,
  201. a string.
  202. Returns
  203. -------
  204. h : callable, log instance or None
  205. The old error handler.
  206. See Also
  207. --------
  208. seterr, geterr, geterrcall
  209. Examples
  210. --------
  211. Callback upon error:
  212. >>> def err_handler(type, flag):
  213. ... print("Floating point error (%s), with flag %s" % (type, flag))
  214. ...
  215. >>> saved_handler = np.seterrcall(err_handler)
  216. >>> save_err = np.seterr(all='call')
  217. >>> np.array([1, 2, 3]) / 0.0
  218. Floating point error (divide by zero), with flag 1
  219. array([inf, inf, inf])
  220. >>> np.seterrcall(saved_handler)
  221. <function err_handler at 0x...>
  222. >>> np.seterr(**save_err)
  223. {'divide': 'call', 'over': 'call', 'under': 'call', 'invalid': 'call'}
  224. Log error message:
  225. >>> class Log:
  226. ... def write(self, msg):
  227. ... print("LOG: %s" % msg)
  228. ...
  229. >>> log = Log()
  230. >>> saved_handler = np.seterrcall(log)
  231. >>> save_err = np.seterr(all='log')
  232. >>> np.array([1, 2, 3]) / 0.0
  233. LOG: Warning: divide by zero encountered in divide
  234. array([inf, inf, inf])
  235. >>> np.seterrcall(saved_handler)
  236. <numpy.core.numeric.Log object at 0x...>
  237. >>> np.seterr(**save_err)
  238. {'divide': 'log', 'over': 'log', 'under': 'log', 'invalid': 'log'}
  239. """
  240. if func is not None and not isinstance(func, collections.abc.Callable):
  241. if (not hasattr(func, 'write') or
  242. not isinstance(func.write, collections.abc.Callable)):
  243. raise ValueError("Only callable can be used as callback")
  244. pyvals = umath.geterrobj()
  245. old = geterrcall()
  246. pyvals[2] = func
  247. umath.seterrobj(pyvals)
  248. return old
  249. @set_module('numpy')
  250. def geterrcall():
  251. """
  252. Return the current callback function used on floating-point errors.
  253. When the error handling for a floating-point error (one of "divide",
  254. "over", "under", or "invalid") is set to 'call' or 'log', the function
  255. that is called or the log instance that is written to is returned by
  256. `geterrcall`. This function or log instance has been set with
  257. `seterrcall`.
  258. Returns
  259. -------
  260. errobj : callable, log instance or None
  261. The current error handler. If no handler was set through `seterrcall`,
  262. ``None`` is returned.
  263. See Also
  264. --------
  265. seterrcall, seterr, geterr
  266. Notes
  267. -----
  268. For complete documentation of the types of floating-point exceptions and
  269. treatment options, see `seterr`.
  270. Examples
  271. --------
  272. >>> np.geterrcall() # we did not yet set a handler, returns None
  273. >>> oldsettings = np.seterr(all='call')
  274. >>> def err_handler(type, flag):
  275. ... print("Floating point error (%s), with flag %s" % (type, flag))
  276. >>> oldhandler = np.seterrcall(err_handler)
  277. >>> np.array([1, 2, 3]) / 0.0
  278. Floating point error (divide by zero), with flag 1
  279. array([inf, inf, inf])
  280. >>> cur_handler = np.geterrcall()
  281. >>> cur_handler is err_handler
  282. True
  283. """
  284. return umath.geterrobj()[2]
  285. class _unspecified:
  286. pass
  287. _Unspecified = _unspecified()
  288. @set_module('numpy')
  289. class errstate(contextlib.ContextDecorator):
  290. """
  291. errstate(**kwargs)
  292. Context manager for floating-point error handling.
  293. Using an instance of `errstate` as a context manager allows statements in
  294. that context to execute with a known error handling behavior. Upon entering
  295. the context the error handling is set with `seterr` and `seterrcall`, and
  296. upon exiting it is reset to what it was before.
  297. .. versionchanged:: 1.17.0
  298. `errstate` is also usable as a function decorator, saving
  299. a level of indentation if an entire function is wrapped.
  300. See :py:class:`contextlib.ContextDecorator` for more information.
  301. Parameters
  302. ----------
  303. kwargs : {divide, over, under, invalid}
  304. Keyword arguments. The valid keywords are the possible floating-point
  305. exceptions. Each keyword should have a string value that defines the
  306. treatment for the particular error. Possible values are
  307. {'ignore', 'warn', 'raise', 'call', 'print', 'log'}.
  308. See Also
  309. --------
  310. seterr, geterr, seterrcall, geterrcall
  311. Notes
  312. -----
  313. For complete documentation of the types of floating-point exceptions and
  314. treatment options, see `seterr`.
  315. Examples
  316. --------
  317. >>> olderr = np.seterr(all='ignore') # Set error handling to known state.
  318. >>> np.arange(3) / 0.
  319. array([nan, inf, inf])
  320. >>> with np.errstate(divide='warn'):
  321. ... np.arange(3) / 0.
  322. array([nan, inf, inf])
  323. >>> np.sqrt(-1)
  324. nan
  325. >>> with np.errstate(invalid='raise'):
  326. ... np.sqrt(-1)
  327. Traceback (most recent call last):
  328. File "<stdin>", line 2, in <module>
  329. FloatingPointError: invalid value encountered in sqrt
  330. Outside the context the error handling behavior has not changed:
  331. >>> np.geterr()
  332. {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
  333. """
  334. def __init__(self, *, call=_Unspecified, **kwargs):
  335. self.call = call
  336. self.kwargs = kwargs
  337. def __enter__(self):
  338. self.oldstate = seterr(**self.kwargs)
  339. if self.call is not _Unspecified:
  340. self.oldcall = seterrcall(self.call)
  341. def __exit__(self, *exc_info):
  342. seterr(**self.oldstate)
  343. if self.call is not _Unspecified:
  344. seterrcall(self.oldcall)
  345. def _setdef():
  346. defval = [UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None]
  347. umath.seterrobj(defval)
  348. # set the default values
  349. _setdef()
  350. NO_NEP50_WARNING = contextvars.ContextVar("_no_nep50_warning", default=False)
  351. @set_module('numpy')
  352. @contextlib.contextmanager
  353. def _no_nep50_warning():
  354. """
  355. Context manager to disable NEP 50 warnings. This context manager is
  356. only relevant if the NEP 50 warnings are enabled globally (which is not
  357. thread/context safe).
  358. This warning context manager itself is fully safe, however.
  359. """
  360. token = NO_NEP50_WARNING.set(True)
  361. try:
  362. yield
  363. finally:
  364. NO_NEP50_WARNING.reset(token)