# overrides.py
"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os

from numpy.core._multiarray_umath import (
    add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
  8. ARRAY_FUNCTION_ENABLED = bool(
  9. int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
  10. array_function_like_doc = (
  11. """like : array_like, optional
  12. Reference object to allow the creation of arrays which are not
  13. NumPy arrays. If an array-like passed in as ``like`` supports
  14. the ``__array_function__`` protocol, the result will be defined
  15. by it. In this case, it ensures the creation of an array object
  16. compatible with that passed in via this argument."""
  17. )
  18. def set_array_function_like_doc(public_api):
  19. if public_api.__doc__ is not None:
  20. public_api.__doc__ = public_api.__doc__.replace(
  21. "${ARRAY_FUNCTION_LIKE}",
  22. array_function_like_doc,
  23. )
  24. return public_api
  25. add_docstring(
  26. implement_array_function,
  27. """
  28. Implement a function with checks for __array_function__ overrides.
  29. All arguments are required, and can only be passed by position.
  30. Parameters
  31. ----------
  32. implementation : function
  33. Function that implements the operation on NumPy array without
  34. overrides when called like ``implementation(*args, **kwargs)``.
  35. public_api : function
  36. Function exposed by NumPy's public API originally called like
  37. ``public_api(*args, **kwargs)`` on which arguments are now being
  38. checked.
  39. relevant_args : iterable
  40. Iterable of arguments to check for __array_function__ methods.
  41. args : tuple
  42. Arbitrary positional arguments originally passed into ``public_api``.
  43. kwargs : dict
  44. Arbitrary keyword arguments originally passed into ``public_api``.
  45. Returns
  46. -------
  47. Result from calling ``implementation()`` or an ``__array_function__``
  48. method, as appropriate.
  49. Raises
  50. ------
  51. TypeError : if no implementation is found.
  52. """)
  53. # exposed for testing purposes; used internally by implement_array_function
  54. add_docstring(
  55. _get_implementing_args,
  56. """
  57. Collect arguments on which to call __array_function__.
  58. Parameters
  59. ----------
  60. relevant_args : iterable of array-like
  61. Iterable of possibly array-like arguments to check for
  62. __array_function__ methods.
  63. Returns
  64. -------
  65. Sequence of arguments with __array_function__ methods, in the order in
  66. which they should be called.
  67. """)
  68. ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
  69. def verify_matching_signatures(implementation, dispatcher):
  70. """Verify that a dispatcher function has the right signature."""
  71. implementation_spec = ArgSpec(*getargspec(implementation))
  72. dispatcher_spec = ArgSpec(*getargspec(dispatcher))
  73. if (implementation_spec.args != dispatcher_spec.args or
  74. implementation_spec.varargs != dispatcher_spec.varargs or
  75. implementation_spec.keywords != dispatcher_spec.keywords or
  76. (bool(implementation_spec.defaults) !=
  77. bool(dispatcher_spec.defaults)) or
  78. (implementation_spec.defaults is not None and
  79. len(implementation_spec.defaults) !=
  80. len(dispatcher_spec.defaults))):
  81. raise RuntimeError('implementation and dispatcher for %s have '
  82. 'different function signatures' % implementation)
  83. if implementation_spec.defaults is not None:
  84. if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
  85. raise RuntimeError('dispatcher functions can only use None for '
  86. 'default argument values')
  87. def set_module(module):
  88. """Decorator for overriding __module__ on a function or class.
  89. Example usage::
  90. @set_module('numpy')
  91. def example():
  92. pass
  93. assert example.__module__ == 'numpy'
  94. """
  95. def decorator(func):
  96. if module is not None:
  97. func.__module__ = module
  98. return func
  99. return decorator
  100. def array_function_dispatch(dispatcher, module=None, verify=True,
  101. docs_from_dispatcher=False, use_like=False):
  102. """Decorator for adding dispatch with the __array_function__ protocol.
  103. See NEP-18 for example usage.
  104. Parameters
  105. ----------
  106. dispatcher : callable
  107. Function that when called like ``dispatcher(*args, **kwargs)`` with
  108. arguments from the NumPy function call returns an iterable of
  109. array-like arguments to check for ``__array_function__``.
  110. module : str, optional
  111. __module__ attribute to set on new function, e.g., ``module='numpy'``.
  112. By default, module is copied from the decorated function.
  113. verify : bool, optional
  114. If True, verify the that the signature of the dispatcher and decorated
  115. function signatures match exactly: all required and optional arguments
  116. should appear in order with the same names, but the default values for
  117. all optional arguments should be ``None``. Only disable verification
  118. if the dispatcher's signature needs to deviate for some particular
  119. reason, e.g., because the function has a signature like
  120. ``func(*args, **kwargs)``.
  121. docs_from_dispatcher : bool, optional
  122. If True, copy docs from the dispatcher function onto the dispatched
  123. function, rather than from the implementation. This is useful for
  124. functions defined in C, which otherwise don't have docstrings.
  125. Returns
  126. -------
  127. Function suitable for decorating the implementation of a NumPy function.
  128. """
  129. if not ARRAY_FUNCTION_ENABLED:
  130. def decorator(implementation):
  131. if docs_from_dispatcher:
  132. add_docstring(implementation, dispatcher.__doc__)
  133. if module is not None:
  134. implementation.__module__ = module
  135. return implementation
  136. return decorator
  137. def decorator(implementation):
  138. if verify:
  139. verify_matching_signatures(implementation, dispatcher)
  140. if docs_from_dispatcher:
  141. add_docstring(implementation, dispatcher.__doc__)
  142. @functools.wraps(implementation)
  143. def public_api(*args, **kwargs):
  144. try:
  145. relevant_args = dispatcher(*args, **kwargs)
  146. except TypeError as exc:
  147. # Try to clean up a signature related TypeError. Such an
  148. # error will be something like:
  149. # dispatcher.__name__() got an unexpected keyword argument
  150. #
  151. # So replace the dispatcher name in this case. In principle
  152. # TypeErrors may be raised from _within_ the dispatcher, so
  153. # we check that the traceback contains a string that starts
  154. # with the name. (In principle we could also check the
  155. # traceback length, as it would be deeper.)
  156. msg = exc.args[0]
  157. disp_name = dispatcher.__name__
  158. if not isinstance(msg, str) or not msg.startswith(disp_name):
  159. raise
  160. # Replace with the correct name and re-raise:
  161. new_msg = msg.replace(disp_name, public_api.__name__)
  162. raise TypeError(new_msg) from None
  163. return implement_array_function(
  164. implementation, public_api, relevant_args, args, kwargs,
  165. use_like)
  166. public_api.__code__ = public_api.__code__.replace(
  167. co_name=implementation.__name__,
  168. co_filename='<__array_function__ internals>')
  169. if module is not None:
  170. public_api.__module__ = module
  171. public_api._implementation = implementation
  172. return public_api
  173. return decorator
  174. def array_function_from_dispatcher(
  175. implementation, module=None, verify=True, docs_from_dispatcher=True):
  176. """Like array_function_dispatcher, but with function arguments flipped."""
  177. def decorator(dispatcher):
  178. return array_function_dispatch(
  179. dispatcher, module, verify=verify,
  180. docs_from_dispatcher=docs_from_dispatcher)(implementation)
  181. return decorator