backends.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. """
  2. Code to support various backends in a plugin dispatch architecture.
  3. Create a Dispatcher
  4. -------------------
  5. To be a valid plugin, a package must register an entry_point
  6. of `networkx.plugins` with a key pointing to the handler.
  7. For example::
  8. entry_points={'networkx.plugins': 'sparse = networkx_plugin_sparse'}
  9. The plugin must create a Graph-like object which contains an attribute
  10. ``__networkx_plugin__`` with a value of the entry point name.
  11. Continuing the example above::
  12. class WrappedSparse:
  13. __networkx_plugin__ = "sparse"
  14. ...
  15. When a dispatchable NetworkX algorithm encounters a Graph-like object
  16. with a ``__networkx_plugin__`` attribute, it will look for the associated
  17. dispatch object in the entry_points, load it, and dispatch the work to it.
  18. Testing
  19. -------
  20. To assist in validating the backend algorithm implementations, if an
  21. environment variable ``NETWORKX_GRAPH_CONVERT`` is set to a registered
  22. plugin key, the dispatch machinery will automatically convert regular
  23. networkx Graphs and DiGraphs to the backend equivalent by calling
  24. ``<backend dispatcher>.convert_from_nx(G, weight=weight, name=name)``.
  25. The converted object is then passed to the backend implementation of
  26. the algorithm. The result is then passed to
  27. ``<backend dispatcher>.convert_to_nx(result, name=name)`` to convert back
  28. to a form expected by the NetworkX tests.
  29. By defining ``convert_from_nx`` and ``convert_to_nx`` methods and setting
  30. the environment variable, NetworkX will automatically route tests on
  31. dispatchable algorithms to the backend, allowing the full networkx test
  32. suite to be run against the backend implementation.
  33. Example pytest invocation::
  34. NETWORKX_GRAPH_CONVERT=sparse pytest --pyargs networkx
  35. Dispatchable algorithms which are not implemented by the backend
  36. will cause a ``pytest.xfail()``, giving some indication that not all
  37. tests are working, while avoiding causing an explicit failure.
  38. A special ``on_start_tests(items)`` function may be defined by the backend.
  39. It will be called with the list of NetworkX tests discovered. Each item
  40. is a test object that can be marked as xfail if the backend does not support
  41. the test using ``item.add_marker(pytest.mark.xfail(reason=...))``.
  42. """
  43. import functools
  44. import inspect
  45. import os
  46. import sys
  47. from importlib.metadata import entry_points
  48. from ..exception import NetworkXNotImplemented
  49. __all__ = ["_dispatch", "_mark_tests"]
  50. class PluginInfo:
  51. """Lazily loaded entry_points plugin information"""
  52. def __init__(self):
  53. self._items = None
  54. def __bool__(self):
  55. return len(self.items) > 0
  56. @property
  57. def items(self):
  58. if self._items is None:
  59. if sys.version_info < (3, 10):
  60. self._items = entry_points()["networkx.plugins"]
  61. else:
  62. self._items = entry_points(group="networkx.plugins")
  63. return self._items
  64. def __contains__(self, name):
  65. if sys.version_info < (3, 10):
  66. return len([ep for ep in self.items if ep.name == name]) > 0
  67. return name in self.items.names
  68. def __getitem__(self, name):
  69. if sys.version_info < (3, 10):
  70. return [ep for ep in self.items if ep.name == name][0]
  71. return self.items[name]
  72. plugins = PluginInfo()
  73. _registered_algorithms = {}
  74. def _register_algo(name, wrapped_func):
  75. if name in _registered_algorithms:
  76. raise KeyError(f"Algorithm already exists in dispatch registry: {name}")
  77. _registered_algorithms[name] = wrapped_func
  78. wrapped_func.dispatchname = name
  79. def _dispatch(func=None, *, name=None):
  80. """Dispatches to a backend algorithm
  81. when the first argument is a backend graph-like object.
  82. """
  83. # Allow any of the following decorator forms:
  84. # - @_dispatch
  85. # - @_dispatch()
  86. # - @_dispatch("override_name")
  87. # - @_dispatch(name="override_name")
  88. if func is None:
  89. if name is None:
  90. return _dispatch
  91. return functools.partial(_dispatch, name=name)
  92. if isinstance(func, str):
  93. return functools.partial(_dispatch, name=func)
  94. # If name not provided, use the name of the function
  95. if name is None:
  96. name = func.__name__
  97. @functools.wraps(func)
  98. def wrapper(*args, **kwds):
  99. if args:
  100. graph = args[0]
  101. else:
  102. try:
  103. graph = kwds["G"]
  104. except KeyError:
  105. raise TypeError(f"{name}() missing positional argument: 'G'") from None
  106. if hasattr(graph, "__networkx_plugin__") and plugins:
  107. plugin_name = graph.__networkx_plugin__
  108. if plugin_name in plugins:
  109. backend = plugins[plugin_name].load()
  110. if hasattr(backend, name):
  111. return getattr(backend, name).__call__(*args, **kwds)
  112. else:
  113. raise NetworkXNotImplemented(
  114. f"'{name}' not implemented by {plugin_name}"
  115. )
  116. return func(*args, **kwds)
  117. # Keep a handle to the original function to use when testing
  118. # the dispatch mechanism internally
  119. wrapper._orig_func = func
  120. _register_algo(name, wrapper)
  121. return wrapper
  122. def test_override_dispatch(func=None, *, name=None):
  123. """Auto-converts the first argument into the backend equivalent,
  124. causing the dispatching mechanism to trigger for every
  125. decorated algorithm."""
  126. if func is None:
  127. if name is None:
  128. return test_override_dispatch
  129. return functools.partial(test_override_dispatch, name=name)
  130. if isinstance(func, str):
  131. return functools.partial(test_override_dispatch, name=func)
  132. # If name not provided, use the name of the function
  133. if name is None:
  134. name = func.__name__
  135. sig = inspect.signature(func)
  136. @functools.wraps(func)
  137. def wrapper(*args, **kwds):
  138. backend = plugins[plugin_name].load()
  139. if not hasattr(backend, name):
  140. if plugin_name == "nx-loopback":
  141. raise NetworkXNotImplemented(
  142. f"'{name}' not found in {backend.__class__.__name__}"
  143. )
  144. pytest.xfail(f"'{name}' not implemented by {plugin_name}")
  145. bound = sig.bind(*args, **kwds)
  146. bound.apply_defaults()
  147. if args:
  148. graph, *args = args
  149. else:
  150. try:
  151. graph = kwds.pop("G")
  152. except KeyError:
  153. raise TypeError(f"{name}() missing positional argument: 'G'") from None
  154. # Convert graph into backend graph-like object
  155. # Include the weight label, if provided to the algorithm
  156. weight = None
  157. if "weight" in bound.arguments:
  158. weight = bound.arguments["weight"]
  159. elif "data" in bound.arguments and "default" in bound.arguments:
  160. # This case exists for several MultiGraph edge algorithms
  161. if isinstance(bound.arguments["data"], str):
  162. weight = bound.arguments["data"]
  163. elif bound.arguments["data"]:
  164. weight = "weight"
  165. graph = backend.convert_from_nx(graph, weight=weight, name=name)
  166. result = getattr(backend, name).__call__(graph, *args, **kwds)
  167. return backend.convert_to_nx(result, name=name)
  168. wrapper._orig_func = func
  169. _register_algo(name, wrapper)
  170. return wrapper
  171. # Check for auto-convert testing
  172. # This allows existing NetworkX tests to be run against a backend
  173. # implementation without any changes to the testing code. The only
  174. # required change is to set an environment variable prior to running
  175. # pytest.
  176. if os.environ.get("NETWORKX_GRAPH_CONVERT"):
  177. plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
  178. if not plugins:
  179. raise Exception("No registered networkx.plugins entry_points")
  180. if plugin_name not in plugins:
  181. raise Exception(
  182. f"No registered networkx.plugins entry_point named {plugin_name}"
  183. )
  184. try:
  185. import pytest
  186. except ImportError:
  187. raise ImportError(
  188. f"Missing pytest, which is required when using NETWORKX_GRAPH_CONVERT"
  189. )
  190. # Override `dispatch` for testing
  191. _dispatch = test_override_dispatch
  192. def _mark_tests(items):
  193. """Allow backend to mark tests (skip or xfail) if they aren't able to correctly handle them"""
  194. if os.environ.get("NETWORKX_GRAPH_CONVERT"):
  195. plugin_name = os.environ["NETWORKX_GRAPH_CONVERT"]
  196. backend = plugins[plugin_name].load()
  197. if hasattr(backend, "on_start_tests"):
  198. getattr(backend, "on_start_tests")(items)