_backend.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703
  1. import typing
  2. import types
  3. import inspect
  4. import functools
  5. from . import _uarray
  6. import copyreg
  7. import pickle
  8. import contextlib
  9. ArgumentExtractorType = typing.Callable[..., typing.Tuple["Dispatchable", ...]]
  10. ArgumentReplacerType = typing.Callable[
  11. [typing.Tuple, typing.Dict, typing.Tuple], typing.Tuple[typing.Tuple, typing.Dict]
  12. ]
  13. from ._uarray import ( # type: ignore
  14. BackendNotImplementedError,
  15. _Function,
  16. _SkipBackendContext,
  17. _SetBackendContext,
  18. _BackendState,
  19. )
  20. __all__ = [
  21. "set_backend",
  22. "set_global_backend",
  23. "skip_backend",
  24. "register_backend",
  25. "determine_backend",
  26. "determine_backend_multi",
  27. "clear_backends",
  28. "create_multimethod",
  29. "generate_multimethod",
  30. "_Function",
  31. "BackendNotImplementedError",
  32. "Dispatchable",
  33. "wrap_single_convertor",
  34. "wrap_single_convertor_instance",
  35. "all_of_type",
  36. "mark_as",
  37. "set_state",
  38. "get_state",
  39. "reset_state",
  40. "_BackendState",
  41. "_SkipBackendContext",
  42. "_SetBackendContext",
  43. ]
  44. def unpickle_function(mod_name, qname, self_):
  45. import importlib
  46. try:
  47. module = importlib.import_module(mod_name)
  48. qname = qname.split(".")
  49. func = module
  50. for q in qname:
  51. func = getattr(func, q)
  52. if self_ is not None:
  53. func = types.MethodType(func, self_)
  54. return func
  55. except (ImportError, AttributeError) as e:
  56. from pickle import UnpicklingError
  57. raise UnpicklingError from e
  58. def pickle_function(func):
  59. mod_name = getattr(func, "__module__", None)
  60. qname = getattr(func, "__qualname__", None)
  61. self_ = getattr(func, "__self__", None)
  62. try:
  63. test = unpickle_function(mod_name, qname, self_)
  64. except pickle.UnpicklingError:
  65. test = None
  66. if test is not func:
  67. raise pickle.PicklingError(
  68. "Can't pickle {}: it's not the same object as {}".format(func, test)
  69. )
  70. return unpickle_function, (mod_name, qname, self_)
  71. def pickle_state(state):
  72. return _uarray._BackendState._unpickle, state._pickle()
  73. def pickle_set_backend_context(ctx):
  74. return _SetBackendContext, ctx._pickle()
  75. def pickle_skip_backend_context(ctx):
  76. return _SkipBackendContext, ctx._pickle()
  77. copyreg.pickle(_Function, pickle_function)
  78. copyreg.pickle(_uarray._BackendState, pickle_state)
  79. copyreg.pickle(_SetBackendContext, pickle_set_backend_context)
  80. copyreg.pickle(_SkipBackendContext, pickle_skip_backend_context)
  81. def get_state():
  82. """
  83. Returns an opaque object containing the current state of all the backends.
  84. Can be used for synchronization between threads/processes.
  85. See Also
  86. --------
  87. set_state
  88. Sets the state returned by this function.
  89. """
  90. return _uarray.get_state()
  91. @contextlib.contextmanager
  92. def reset_state():
  93. """
  94. Returns a context manager that resets all state once exited.
  95. See Also
  96. --------
  97. set_state
  98. Context manager that sets the backend state.
  99. get_state
  100. Gets a state to be set by this context manager.
  101. """
  102. with set_state(get_state()):
  103. yield
  104. @contextlib.contextmanager
  105. def set_state(state):
  106. """
  107. A context manager that sets the state of the backends to one returned by :obj:`get_state`.
  108. See Also
  109. --------
  110. get_state
  111. Gets a state to be set by this context manager.
  112. """
  113. old_state = get_state()
  114. _uarray.set_state(state)
  115. try:
  116. yield
  117. finally:
  118. _uarray.set_state(old_state, True)
  119. def create_multimethod(*args, **kwargs):
  120. """
  121. Creates a decorator for generating multimethods.
  122. This function creates a decorator that can be used with an argument
  123. extractor in order to generate a multimethod. Other than for the
  124. argument extractor, all arguments are passed on to
  125. :obj:`generate_multimethod`.
  126. See Also
  127. --------
  128. generate_multimethod
  129. Generates a multimethod.
  130. """
  131. def wrapper(a):
  132. return generate_multimethod(a, *args, **kwargs)
  133. return wrapper
  134. def generate_multimethod(
  135. argument_extractor: ArgumentExtractorType,
  136. argument_replacer: ArgumentReplacerType,
  137. domain: str,
  138. default: typing.Optional[typing.Callable] = None,
  139. ):
  140. """
  141. Generates a multimethod.
  142. Parameters
  143. ----------
  144. argument_extractor : ArgumentExtractorType
  145. A callable which extracts the dispatchable arguments. Extracted arguments
  146. should be marked by the :obj:`Dispatchable` class. It has the same signature
  147. as the desired multimethod.
  148. argument_replacer : ArgumentReplacerType
  149. A callable with the signature (args, kwargs, dispatchables), which should also
  150. return an (args, kwargs) pair with the dispatchables replaced inside the args/kwargs.
  151. domain : str
  152. A string value indicating the domain of this multimethod.
  153. default: Optional[Callable], optional
  154. The default implementation of this multimethod, where ``None`` (the default) specifies
  155. there is no default implementation.
  156. Examples
  157. --------
  158. In this example, ``a`` is to be dispatched over, so we return it, while marking it as an ``int``.
  159. The trailing comma is needed because the args have to be returned as an iterable.
  160. >>> def override_me(a, b):
  161. ... return Dispatchable(a, int),
  162. Next, we define the argument replacer that replaces the dispatchables inside args/kwargs with the
  163. supplied ones.
  164. >>> def override_replacer(args, kwargs, dispatchables):
  165. ... return (dispatchables[0], args[1]), {}
  166. Next, we define the multimethod.
  167. >>> overridden_me = generate_multimethod(
  168. ... override_me, override_replacer, "ua_examples"
  169. ... )
  170. Notice that there's no default implementation, unless you supply one.
  171. >>> overridden_me(1, "a")
  172. Traceback (most recent call last):
  173. ...
  174. uarray.BackendNotImplementedError: ...
  175. >>> overridden_me2 = generate_multimethod(
  176. ... override_me, override_replacer, "ua_examples", default=lambda x, y: (x, y)
  177. ... )
  178. >>> overridden_me2(1, "a")
  179. (1, 'a')
  180. See Also
  181. --------
  182. uarray
  183. See the module documentation for how to override the method by creating backends.
  184. """
  185. kw_defaults, arg_defaults, opts = get_defaults(argument_extractor)
  186. ua_func = _Function(
  187. argument_extractor,
  188. argument_replacer,
  189. domain,
  190. arg_defaults,
  191. kw_defaults,
  192. default,
  193. )
  194. return functools.update_wrapper(ua_func, argument_extractor)
  195. def set_backend(backend, coerce=False, only=False):
  196. """
  197. A context manager that sets the preferred backend.
  198. Parameters
  199. ----------
  200. backend
  201. The backend to set.
  202. coerce
  203. Whether or not to coerce to a specific backend's types. Implies ``only``.
  204. only
  205. Whether or not this should be the last backend to try.
  206. See Also
  207. --------
  208. skip_backend: A context manager that allows skipping of backends.
  209. set_global_backend: Set a single, global backend for a domain.
  210. """
  211. try:
  212. return backend.__ua_cache__["set", coerce, only]
  213. except AttributeError:
  214. backend.__ua_cache__ = {}
  215. except KeyError:
  216. pass
  217. ctx = _SetBackendContext(backend, coerce, only)
  218. backend.__ua_cache__["set", coerce, only] = ctx
  219. return ctx
  220. def skip_backend(backend):
  221. """
  222. A context manager that allows one to skip a given backend from processing
  223. entirely. This allows one to use another backend's code in a library that
  224. is also a consumer of the same backend.
  225. Parameters
  226. ----------
  227. backend
  228. The backend to skip.
  229. See Also
  230. --------
  231. set_backend: A context manager that allows setting of backends.
  232. set_global_backend: Set a single, global backend for a domain.
  233. """
  234. try:
  235. return backend.__ua_cache__["skip"]
  236. except AttributeError:
  237. backend.__ua_cache__ = {}
  238. except KeyError:
  239. pass
  240. ctx = _SkipBackendContext(backend)
  241. backend.__ua_cache__["skip"] = ctx
  242. return ctx
  243. def get_defaults(f):
  244. sig = inspect.signature(f)
  245. kw_defaults = {}
  246. arg_defaults = []
  247. opts = set()
  248. for k, v in sig.parameters.items():
  249. if v.default is not inspect.Parameter.empty:
  250. kw_defaults[k] = v.default
  251. if v.kind in (
  252. inspect.Parameter.POSITIONAL_ONLY,
  253. inspect.Parameter.POSITIONAL_OR_KEYWORD,
  254. ):
  255. arg_defaults.append(v.default)
  256. opts.add(k)
  257. return kw_defaults, tuple(arg_defaults), opts
  258. def set_global_backend(backend, coerce=False, only=False, *, try_last=False):
  259. """
  260. This utility method replaces the default backend for permanent use. It
  261. will be tried in the list of backends automatically, unless the
  262. ``only`` flag is set on a backend. This will be the first tried
  263. backend outside the :obj:`set_backend` context manager.
  264. Note that this method is not thread-safe.
  265. .. warning::
  266. We caution library authors against using this function in
  267. their code. We do *not* support this use-case. This function
  268. is meant to be used only by users themselves, or by a reference
  269. implementation, if one exists.
  270. Parameters
  271. ----------
  272. backend
  273. The backend to register.
  274. coerce : bool
  275. Whether to coerce input types when trying this backend.
  276. only : bool
  277. If ``True``, no more backends will be tried if this fails.
  278. Implied by ``coerce=True``.
  279. try_last : bool
  280. If ``True``, the global backend is tried after registered backends.
  281. See Also
  282. --------
  283. set_backend: A context manager that allows setting of backends.
  284. skip_backend: A context manager that allows skipping of backends.
  285. """
  286. _uarray.set_global_backend(backend, coerce, only, try_last)
  287. def register_backend(backend):
  288. """
  289. This utility method sets registers backend for permanent use. It
  290. will be tried in the list of backends automatically, unless the
  291. ``only`` flag is set on a backend.
  292. Note that this method is not thread-safe.
  293. Parameters
  294. ----------
  295. backend
  296. The backend to register.
  297. """
  298. _uarray.register_backend(backend)
  299. def clear_backends(domain, registered=True, globals=False):
  300. """
  301. This utility method clears registered backends.
  302. .. warning::
  303. We caution library authors against using this function in
  304. their code. We do *not* support this use-case. This function
  305. is meant to be used only by users themselves.
  306. .. warning::
  307. Do NOT use this method inside a multimethod call, or the
  308. program is likely to crash.
  309. Parameters
  310. ----------
  311. domain : Optional[str]
  312. The domain for which to de-register backends. ``None`` means
  313. de-register for all domains.
  314. registered : bool
  315. Whether or not to clear registered backends. See :obj:`register_backend`.
  316. globals : bool
  317. Whether or not to clear global backends. See :obj:`set_global_backend`.
  318. See Also
  319. --------
  320. register_backend : Register a backend globally.
  321. set_global_backend : Set a global backend.
  322. """
  323. _uarray.clear_backends(domain, registered, globals)
  324. class Dispatchable:
  325. """
  326. A utility class which marks an argument with a specific dispatch type.
  327. Attributes
  328. ----------
  329. value
  330. The value of the Dispatchable.
  331. type
  332. The type of the Dispatchable.
  333. Examples
  334. --------
  335. >>> x = Dispatchable(1, str)
  336. >>> x
  337. <Dispatchable: type=<class 'str'>, value=1>
  338. See Also
  339. --------
  340. all_of_type
  341. Marks all unmarked parameters of a function.
  342. mark_as
  343. Allows one to create a utility function to mark as a given type.
  344. """
  345. def __init__(self, value, dispatch_type, coercible=True):
  346. self.value = value
  347. self.type = dispatch_type
  348. self.coercible = coercible
  349. def __getitem__(self, index):
  350. return (self.type, self.value)[index]
  351. def __str__(self):
  352. return "<{0}: type={1!r}, value={2!r}>".format(
  353. type(self).__name__, self.type, self.value
  354. )
  355. __repr__ = __str__
  356. def mark_as(dispatch_type):
  357. """
  358. Creates a utility function to mark something as a specific type.
  359. Examples
  360. --------
  361. >>> mark_int = mark_as(int)
  362. >>> mark_int(1)
  363. <Dispatchable: type=<class 'int'>, value=1>
  364. """
  365. return functools.partial(Dispatchable, dispatch_type=dispatch_type)
  366. def all_of_type(arg_type):
  367. """
  368. Marks all unmarked arguments as a given type.
  369. Examples
  370. --------
  371. >>> @all_of_type(str)
  372. ... def f(a, b):
  373. ... return a, Dispatchable(b, int)
  374. >>> f('a', 1)
  375. (<Dispatchable: type=<class 'str'>, value='a'>, <Dispatchable: type=<class 'int'>, value=1>)
  376. """
  377. def outer(func):
  378. @functools.wraps(func)
  379. def inner(*args, **kwargs):
  380. extracted_args = func(*args, **kwargs)
  381. return tuple(
  382. Dispatchable(arg, arg_type)
  383. if not isinstance(arg, Dispatchable)
  384. else arg
  385. for arg in extracted_args
  386. )
  387. return inner
  388. return outer
  389. def wrap_single_convertor(convert_single):
  390. """
  391. Wraps a ``__ua_convert__`` defined for a single element to all elements.
  392. If any of them return ``NotImplemented``, the operation is assumed to be
  393. undefined.
  394. Accepts a signature of (value, type, coerce).
  395. """
  396. @functools.wraps(convert_single)
  397. def __ua_convert__(dispatchables, coerce):
  398. converted = []
  399. for d in dispatchables:
  400. c = convert_single(d.value, d.type, coerce and d.coercible)
  401. if c is NotImplemented:
  402. return NotImplemented
  403. converted.append(c)
  404. return converted
  405. return __ua_convert__
  406. def wrap_single_convertor_instance(convert_single):
  407. """
  408. Wraps a ``__ua_convert__`` defined for a single element to all elements.
  409. If any of them return ``NotImplemented``, the operation is assumed to be
  410. undefined.
  411. Accepts a signature of (value, type, coerce).
  412. """
  413. @functools.wraps(convert_single)
  414. def __ua_convert__(self, dispatchables, coerce):
  415. converted = []
  416. for d in dispatchables:
  417. c = convert_single(self, d.value, d.type, coerce and d.coercible)
  418. if c is NotImplemented:
  419. return NotImplemented
  420. converted.append(c)
  421. return converted
  422. return __ua_convert__
  423. def determine_backend(value, dispatch_type, *, domain, only=True, coerce=False):
  424. """Set the backend to the first active backend that supports ``value``
  425. This is useful for functions that call multimethods without any dispatchable
  426. arguments. You can use :func:`determine_backend` to ensure the same backend
  427. is used everywhere in a block of multimethod calls.
  428. Parameters
  429. ----------
  430. value
  431. The value being tested
  432. dispatch_type
  433. The dispatch type associated with ``value``, aka
  434. ":ref:`marking <MarkingGlossary>`".
  435. domain: string
  436. The domain to query for backends and set.
  437. coerce: bool
  438. Whether or not to allow coercion to the backend's types. Implies ``only``.
  439. only: bool
  440. Whether or not this should be the last backend to try.
  441. See Also
  442. --------
  443. set_backend: For when you know which backend to set
  444. Notes
  445. -----
  446. Support is determined by the ``__ua_convert__`` protocol. Backends not
  447. supporting the type must return ``NotImplemented`` from their
  448. ``__ua_convert__`` if they don't support input of that type.
  449. Examples
  450. --------
  451. Suppose we have two backends ``BackendA`` and ``BackendB`` each supporting
  452. different types, ``TypeA`` and ``TypeB``. Neither supporting the other type:
  453. >>> with ua.set_backend(ex.BackendA):
  454. ... ex.call_multimethod(ex.TypeB(), ex.TypeB())
  455. Traceback (most recent call last):
  456. ...
  457. uarray.BackendNotImplementedError: ...
  458. Now consider a multimethod that creates a new object of ``TypeA``, or
  459. ``TypeB`` depending on the active backend.
  460. >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
  461. ... res = ex.creation_multimethod()
  462. ... ex.call_multimethod(res, ex.TypeA())
  463. Traceback (most recent call last):
  464. ...
  465. uarray.BackendNotImplementedError: ...
  466. ``res`` is an object of ``TypeB`` because ``BackendB`` is set in the
  467. innermost with statement. So, ``call_multimethod`` fails since the types
  468. don't match.
  469. Instead, we need to first find a backend suitable for all of our objects.
  470. >>> with ua.set_backend(ex.BackendA), ua.set_backend(ex.BackendB):
  471. ... x = ex.TypeA()
  472. ... with ua.determine_backend(x, "mark", domain="ua_examples"):
  473. ... res = ex.creation_multimethod()
  474. ... ex.call_multimethod(res, x)
  475. TypeA
  476. """
  477. dispatchables = (Dispatchable(value, dispatch_type, coerce),)
  478. backend = _uarray.determine_backend(domain, dispatchables, coerce)
  479. return set_backend(backend, coerce=coerce, only=only)
  480. def determine_backend_multi(
  481. dispatchables, *, domain, only=True, coerce=False, **kwargs
  482. ):
  483. """Set a backend supporting all ``dispatchables``
  484. This is useful for functions that call multimethods without any dispatchable
  485. arguments. You can use :func:`determine_backend_multi` to ensure the same
  486. backend is used everywhere in a block of multimethod calls involving
  487. multiple arrays.
  488. Parameters
  489. ----------
  490. dispatchables: Sequence[Union[uarray.Dispatchable, Any]]
  491. The dispatchables that must be supported
  492. domain: string
  493. The domain to query for backends and set.
  494. coerce: bool
  495. Whether or not to allow coercion to the backend's types. Implies ``only``.
  496. only: bool
  497. Whether or not this should be the last backend to try.
  498. dispatch_type: Optional[Any]
  499. The default dispatch type associated with ``dispatchables``, aka
  500. ":ref:`marking <MarkingGlossary>`".
  501. See Also
  502. --------
  503. determine_backend: For a single dispatch value
  504. set_backend: For when you know which backend to set
  505. Notes
  506. -----
  507. Support is determined by the ``__ua_convert__`` protocol. Backends not
  508. supporting the type must return ``NotImplemented`` from their
  509. ``__ua_convert__`` if they don't support input of that type.
  510. Examples
  511. --------
  512. :func:`determine_backend` allows the backend to be set from a single
  513. object. :func:`determine_backend_multi` allows multiple objects to be
  514. checked simultaneously for support in the backend. Suppose we have a
  515. ``BackendAB`` which supports ``TypeA`` and ``TypeB`` in the same call,
  516. and a ``BackendBC`` that doesn't support ``TypeA``.
  517. >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
  518. ... a, b = ex.TypeA(), ex.TypeB()
  519. ... with ua.determine_backend_multi(
  520. ... [ua.Dispatchable(a, "mark"), ua.Dispatchable(b, "mark")],
  521. ... domain="ua_examples"
  522. ... ):
  523. ... res = ex.creation_multimethod()
  524. ... ex.call_multimethod(res, a, b)
  525. TypeA
  526. This won't call ``BackendBC`` because it doesn't support ``TypeA``.
  527. We can also use leave out the ``ua.Dispatchable`` if we specify the
  528. default ``dispatch_type`` for the ``dispatchables`` argument.
  529. >>> with ua.set_backend(ex.BackendAB), ua.set_backend(ex.BackendBC):
  530. ... a, b = ex.TypeA(), ex.TypeB()
  531. ... with ua.determine_backend_multi(
  532. ... [a, b], dispatch_type="mark", domain="ua_examples"
  533. ... ):
  534. ... res = ex.creation_multimethod()
  535. ... ex.call_multimethod(res, a, b)
  536. TypeA
  537. """
  538. if "dispatch_type" in kwargs:
  539. disp_type = kwargs.pop("dispatch_type")
  540. dispatchables = tuple(
  541. d if isinstance(d, Dispatchable) else Dispatchable(d, disp_type)
  542. for d in dispatchables
  543. )
  544. else:
  545. dispatchables = tuple(dispatchables)
  546. if not all(isinstance(d, Dispatchable) for d in dispatchables):
  547. raise TypeError("dispatchables must be instances of uarray.Dispatchable")
  548. if len(kwargs) != 0:
  549. raise TypeError("Received unexpected keyword arguments: {}".format(kwargs))
  550. backend = _uarray.determine_backend(domain, dispatchables, coerce)
  551. return set_backend(backend, coerce=coerce, only=only)