123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466 |
- """
- Functions for changing global ufunc configuration
- This provides helpers which wrap `umath.geterrobj` and `umath.seterrobj`
- """
- import collections.abc
- import contextlib
- import contextvars
- from .overrides import set_module
- from .umath import (
- UFUNC_BUFSIZE_DEFAULT,
- ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, ERR_DEFAULT,
- SHIFT_DIVIDEBYZERO, SHIFT_OVERFLOW, SHIFT_UNDERFLOW, SHIFT_INVALID,
- )
- from . import umath
# Public API of this module. ``_no_nep50_warning`` is listed even though it
# is underscore-prefixed, so that ``import *`` re-exports it as well.
__all__ = [
    "seterr",
    "geterr",
    "setbufsize",
    "getbufsize",
    "seterrcall",
    "geterrcall",
    "errstate",
    "_no_nep50_warning",
]
# User-facing error-handling mode names -> umath ERR_* codes.
_errdict = {
    "ignore": ERR_IGNORE,
    "warn": ERR_WARN,
    "raise": ERR_RAISE,
    "call": ERR_CALL,
    "print": ERR_PRINT,
    "log": ERR_LOG,
}
# Inverse mapping (ERR_* code -> mode name), used when decoding the mask.
_errdict_rev = dict(zip(_errdict.values(), _errdict.keys()))
@set_module('numpy')
def seterr(all=None, divide=None, over=None, under=None, invalid=None):
    """
    Set how floating-point errors are handled.

    Note that operations on integer scalar types (such as `int16`) are
    handled like floating point, and are affected by these settings.

    Parameters
    ----------
    all : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Set treatment for all types of floating-point errors at once:

        - ignore: Take no action when the exception occurs.
        - warn: Print a `RuntimeWarning` (via the Python `warnings` module).
        - raise: Raise a `FloatingPointError`.
        - call: Call a function specified using the `seterrcall` function.
        - print: Print a warning directly to ``stdout``.
        - log: Record error in a Log object specified by `seterrcall`.

        The default is not to change the current behavior. An explicit
        per-category argument (`divide`, `over`, `under`, `invalid`) takes
        precedence over `all` for that category.
    divide : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for division by zero.
    over : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for floating-point overflow.
    under : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for floating-point underflow.
    invalid : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for invalid floating-point operation.

    Returns
    -------
    old_settings : dict
        Dictionary containing the old settings.

    Raises
    ------
    KeyError
        If a given treatment is not one of the strings listed above. The
        current settings are left unchanged in that case.

    See also
    --------
    seterrcall : Set a callback function for the 'call' mode.
    geterr, geterrcall, errstate

    Notes
    -----
    The floating-point exceptions are defined in the IEEE 754 standard [1]_:

    - Division by zero: infinite result obtained from finite numbers.
    - Overflow: result too large to be expressed.
    - Underflow: result so close to zero that some precision
      was lost.
    - Invalid operation: result is not an expressible number, typically
      indicates that a NaN was produced.

    .. [1] https://en.wikipedia.org/wiki/IEEE_754

    Examples
    --------
    >>> old_settings = np.seterr(all='ignore')  #seterr to known value
    >>> np.seterr(over='raise')
    {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
    >>> np.seterr(**old_settings)  # reset to default
    {'divide': 'ignore', 'over': 'raise', 'under': 'ignore', 'invalid': 'ignore'}

    >>> np.int16(32000) * np.int16(3)
    30464
    >>> old_settings = np.seterr(all='warn', over='raise')
    >>> np.int16(32000) * np.int16(3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    FloatingPointError: overflow encountered in scalar multiply

    >>> old_settings = np.seterr(all='print')
    >>> np.geterr()
    {'divide': 'print', 'over': 'print', 'under': 'print', 'invalid': 'print'}
    >>> np.int16(32000) * np.int16(3)
    30464
    """
    pyvals = umath.geterrobj()
    old = geterr()

    # Precedence per category: explicit argument, then ``all``, then the
    # current setting. ``all`` is compared with None rather than tested for
    # truthiness, so a falsy but invalid value such as ``''`` is rejected by
    # the ``_errdict`` lookup below instead of being silently ignored.
    fallback = old if all is None else dict.fromkeys(old, all)
    if divide is None:
        divide = fallback['divide']
    if over is None:
        over = fallback['over']
    if under is None:
        under = fallback['under']
    if invalid is None:
        invalid = fallback['invalid']

    # All lookups happen before the error object is written back, so an
    # invalid mode raises without modifying the current state.
    maskvalue = ((_errdict[divide] << SHIFT_DIVIDEBYZERO) +
                 (_errdict[over] << SHIFT_OVERFLOW) +
                 (_errdict[under] << SHIFT_UNDERFLOW) +
                 (_errdict[invalid] << SHIFT_INVALID))

    # Slot 1 of the error object holds the packed error-mode mask.
    pyvals[1] = maskvalue
    umath.seterrobj(pyvals)
    return old
@set_module('numpy')
def geterr():
    """
    Get the current way of handling floating-point errors.

    Returns
    -------
    res : dict
        A dictionary with keys "divide", "over", "under", and "invalid",
        whose values are from the strings "ignore", "print", "log", "warn",
        "raise", and "call". The keys represent possible floating-point
        exceptions, and the values define how these exceptions are handled.

    See Also
    --------
    geterrcall, seterr, seterrcall

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> np.geterr()
    {'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}
    >>> np.arange(3.) / np.arange(3.)
    array([nan,  1.,  1.])

    >>> oldsettings = np.seterr(all='warn', over='raise')
    >>> np.geterr()
    {'divide': 'warn', 'over': 'raise', 'under': 'warn', 'invalid': 'warn'}
    >>> np.arange(3.) / np.arange(3.)
    array([nan,  1.,  1.])
    """
    # Slot 1 of the error object is the packed mask; each category's mode
    # is stored in a field extracted with a shift and the 3-bit mask 0b111.
    packed = umath.geterrobj()[1]
    fields = (
        ('divide', SHIFT_DIVIDEBYZERO),
        ('over', SHIFT_OVERFLOW),
        ('under', SHIFT_UNDERFLOW),
        ('invalid', SHIFT_INVALID),
    )
    return {
        category: _errdict_rev[(packed >> shift) & 0b111]
        for category, shift in fields
    }
@set_module('numpy')
def setbufsize(size):
    """
    Set the size of the buffer used in ufuncs.

    Parameters
    ----------
    size : int
        Size of buffer. It must be a multiple of 16 in the range 16 to
        10e6 (10,000,000), inclusive.

    Returns
    -------
    old_size : int
        The buffer size that was in effect before this call, as returned
        by `getbufsize`.

    Raises
    ------
    ValueError
        If `size` is too big, too small, or not a multiple of 16. The
        current buffer size is left unchanged in that case.

    See Also
    --------
    getbufsize
    """
    # Validate before touching the error object so a bad size is a no-op.
    if size > 10e6:
        raise ValueError("Buffer size, %s, is too big." % size)
    if size < 5:
        raise ValueError("Buffer size, %s, is too small." % size)
    if size % 16 != 0:
        raise ValueError("Buffer size, %s, is not a multiple of 16." % size)

    pyvals = umath.geterrobj()
    old = getbufsize()
    # Slot 0 of the error object holds the buffer size.
    pyvals[0] = size
    umath.seterrobj(pyvals)
    return old
@set_module('numpy')
def getbufsize():
    """
    Return the size of the buffer used in ufuncs.

    Returns
    -------
    getbufsize : int
        Size of ufunc buffer in bytes.

    See Also
    --------
    setbufsize
    """
    errobj = umath.geterrobj()
    # Slot 0 of the error object holds the buffer size.
    return errobj[0]
@set_module('numpy')
def seterrcall(func):
    """
    Set the floating-point error callback function or log object.

    There are two ways to capture floating-point error messages. The first
    is to set the error-handler to 'call', using `seterr`. Then, set
    the function to call using this function.

    The second is to set the error-handler to 'log', using `seterr`.
    Floating-point errors then trigger a call to the 'write' method of
    the provided object.

    Parameters
    ----------
    func : callable f(err, flag) or object with write method
        Function to call upon floating-point errors ('call'-mode) or
        object whose 'write' method is used to log such message ('log'-mode).

        The call function takes two arguments. The first is a string describing
        the type of error (such as "divide by zero", "overflow", "underflow",
        or "invalid value"), and the second is the status flag. The flag is a
        byte, whose four least-significant bits indicate the type of error, one
        of "divide", "over", "under", "invalid"::

          [0 0 0 0 invalid under over divide]

        In other words, ``flags = divide + 2*over + 4*under + 8*invalid``.

        If an object is provided, its write method should take one argument,
        a string.

        ``None`` is also accepted and is stored as the handler.

    Returns
    -------
    h : callable, log instance or None
        The old error handler.

    Raises
    ------
    ValueError
        If `func` is not None, not callable, and has no callable ``write``
        attribute.

    See Also
    --------
    seterr, geterr, geterrcall

    Examples
    --------
    Callback upon error:

    >>> def err_handler(type, flag):
    ...     print("Floating point error (%s), with flag %s" % (type, flag))
    ...

    >>> saved_handler = np.seterrcall(err_handler)
    >>> save_err = np.seterr(all='call')

    >>> np.array([1, 2, 3]) / 0.0
    Floating point error (divide by zero), with flag 1
    array([inf, inf, inf])

    >>> np.seterrcall(saved_handler)
    <function err_handler at 0x...>
    >>> np.seterr(**save_err)
    {'divide': 'call', 'over': 'call', 'under': 'call', 'invalid': 'call'}

    Log error message:

    >>> class Log:
    ...     def write(self, msg):
    ...         print("LOG: %s" % msg)
    ...

    >>> log = Log()
    >>> saved_handler = np.seterrcall(log)
    >>> save_err = np.seterr(all='log')

    >>> np.array([1, 2, 3]) / 0.0
    LOG: Warning: divide by zero encountered in divide
    array([inf, inf, inf])

    >>> np.seterrcall(saved_handler)
    <numpy.core.numeric.Log object at 0x...>
    >>> np.seterr(**save_err)
    {'divide': 'log', 'over': 'log', 'under': 'log', 'invalid': 'log'}
    """
    if func is not None and not callable(func):
        # Not a plain callable: accept it only as a 'log'-mode object,
        # i.e. one exposing a callable ``write`` attribute.
        write = getattr(func, 'write', None)
        if not callable(write):
            raise ValueError("Only callable can be used as callback")
    pyvals = umath.geterrobj()
    old = geterrcall()
    # Slot 2 of the error object holds the callback / log object.
    pyvals[2] = func
    umath.seterrobj(pyvals)
    return old
@set_module('numpy')
def geterrcall():
    """
    Return the current callback function used on floating-point errors.

    When the error handling for a floating-point error (one of "divide",
    "over", "under", or "invalid") is set to 'call' or 'log', the function
    that is called or the log instance that is written to is returned by
    `geterrcall`. This function or log instance has been set with
    `seterrcall`.

    Returns
    -------
    errobj : callable, log instance or None
        The current error handler. If no handler was set through `seterrcall`,
        ``None`` is returned.

    See Also
    --------
    seterrcall, seterr, geterr

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> np.geterrcall()  # we did not yet set a handler, returns None

    >>> oldsettings = np.seterr(all='call')
    >>> def err_handler(type, flag):
    ...     print("Floating point error (%s), with flag %s" % (type, flag))
    >>> oldhandler = np.seterrcall(err_handler)
    >>> np.array([1, 2, 3]) / 0.0
    Floating point error (divide by zero), with flag 1
    array([inf, inf, inf])

    >>> cur_handler = np.geterrcall()
    >>> cur_handler is err_handler
    True
    """
    errobj = umath.geterrobj()
    # Slot 2 of the error object holds the callback / log object.
    return errobj[2]
class _unspecified:
    """Marker type; its single instance means "argument not supplied"."""


# Sentinel used where ``None`` is itself a meaningful argument value
# (e.g. ``errstate(call=None)`` explicitly clears the error callback).
_Unspecified = _unspecified()
@set_module('numpy')
class errstate(contextlib.ContextDecorator):
    """
    errstate(**kwargs)

    Context manager for floating-point error handling.

    Using an instance of `errstate` as a context manager allows statements in
    that context to execute with a known error handling behavior. Upon entering
    the context the error handling is set with `seterr` and `seterrcall`, and
    upon exiting it is reset to what it was before.

    .. versionchanged:: 1.17.0
        `errstate` is also usable as a function decorator, saving
        a level of indentation if an entire function is wrapped.
        See :py:class:`contextlib.ContextDecorator` for more information.

    The same instance may be entered more than once at a time (nested
    ``with`` statements, or recursion through a decorated function); each
    exit restores the state saved by its matching entry.

    Parameters
    ----------
    call : callable or object with a ``write`` method, optional
        Keyword-only. If given, it is installed with `seterrcall` on entry
        and the previous handler is restored on exit. If omitted, the
        current handler is left untouched.
    kwargs : {all, divide, over, under, invalid}
        Keyword arguments forwarded to `seterr`. The valid keywords are the
        possible floating-point exceptions (or ``all`` for every one of
        them). Each keyword should have a string value that defines the
        treatment for the particular error. Possible values are
        {'ignore', 'warn', 'raise', 'call', 'print', 'log'}.

    See Also
    --------
    seterr, geterr, seterrcall, geterrcall

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> olderr = np.seterr(all='ignore')  # Set error handling to known state.

    >>> np.arange(3) / 0.
    array([nan, inf, inf])
    >>> with np.errstate(divide='warn'):
    ...     np.arange(3) / 0.
    array([nan, inf, inf])

    >>> np.sqrt(-1)
    nan
    >>> with np.errstate(invalid='raise'):
    ...     np.sqrt(-1)
    Traceback (most recent call last):
      File "<stdin>", line 2, in <module>
    FloatingPointError: invalid value encountered in sqrt

    Outside the context the error handling behavior has not changed:

    >>> np.geterr()
    {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
    """

    def __init__(self, *, call=_Unspecified, **kwargs):
        self.call = call
        self.kwargs = kwargs
        # Stack of (old error settings, old callback) pairs, one per active
        # entry. ContextDecorator reuses ``self`` for every decorated call,
        # so single attributes would be overwritten by a nested/recursive
        # entry and the outermost exit would restore the wrong state.
        self._saved_states = []

    def __enter__(self):
        oldstate = seterr(**self.kwargs)
        # ``_Unspecified`` marks "callback not changed"; ``None`` cannot be
        # used for that because it is a valid previous handler.
        oldcall = _Unspecified
        if self.call is not _Unspecified:
            try:
                oldcall = seterrcall(self.call)
            except BaseException:
                # ``__exit__`` does not run when ``__enter__`` raises, so undo
                # the ``seterr`` above instead of leaking it.
                seterr(**oldstate)
                raise
        self._saved_states.append((oldstate, oldcall))

    def __exit__(self, *exc_info):
        oldstate, oldcall = self._saved_states.pop()
        seterr(**oldstate)
        if oldcall is not _Unspecified:
            seterrcall(oldcall)
def _setdef():
    # Reset the ufunc error object to its defaults. Slots are
    # [buffer size, packed error-mode mask, error callback].
    umath.seterrobj([UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None])


# Install the default values once, at import time.
_setdef()
# Context-local flag that is True only inside ``_no_nep50_warning()``.
# Presumably consulted by the NEP 50 warning machinery elsewhere -- verify.
NO_NEP50_WARNING = contextvars.ContextVar("_no_nep50_warning", default=False)


@set_module('numpy')
@contextlib.contextmanager
def _no_nep50_warning():
    """
    Context manager that disables NEP 50 warnings.

    It only matters when NEP 50 warnings are enabled globally (a setting
    that is not thread/context safe). This context manager itself is
    context safe: it sets a context variable and, on exit, resets it to
    the value it had on entry.
    """
    reset_token = NO_NEP50_WARNING.set(True)
    try:
        yield
    finally:
        # Runs even if the managed block raises, so the flag never leaks.
        NO_NEP50_WARNING.reset(reset_token)
|