dispatch_interface.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
# This file contains utilities for testing the dispatching feature.
# A full test of all dispatchable algorithms is performed by
# modifying the pytest invocation and setting an environment variable:
#     NETWORKX_GRAPH_CONVERT=nx-loopback pytest
# This is comprehensive, but only tests the `test_override_dispatch`
# function in networkx.classes.backends.
# To test the `_dispatch` function directly, several tests scattered throughout
# NetworkX have been augmented to test both normal and dispatch modes.
# Searching for `dispatch_interface` should locate the specific tests.
  10. import networkx as nx
  11. from networkx import DiGraph, Graph, MultiDiGraph, MultiGraph, PlanarEmbedding
  12. class LoopbackGraph(Graph):
  13. __networkx_plugin__ = "nx-loopback"
  14. class LoopbackDiGraph(DiGraph):
  15. __networkx_plugin__ = "nx-loopback"
  16. class LoopbackMultiGraph(MultiGraph):
  17. __networkx_plugin__ = "nx-loopback"
  18. class LoopbackMultiDiGraph(MultiDiGraph):
  19. __networkx_plugin__ = "nx-loopback"
  20. class LoopbackPlanarEmbedding(PlanarEmbedding):
  21. __networkx_plugin__ = "nx-loopback"
  22. def convert(graph):
  23. if isinstance(graph, PlanarEmbedding):
  24. return LoopbackPlanarEmbedding(graph)
  25. if isinstance(graph, MultiDiGraph):
  26. return LoopbackMultiDiGraph(graph)
  27. if isinstance(graph, MultiGraph):
  28. return LoopbackMultiGraph(graph)
  29. if isinstance(graph, DiGraph):
  30. return LoopbackDiGraph(graph)
  31. if isinstance(graph, Graph):
  32. return LoopbackGraph(graph)
  33. raise TypeError(f"Unsupported type of graph: {type(graph)}")
  34. class LoopbackDispatcher:
  35. non_toplevel = {
  36. "inter_community_edges": nx.community.quality.inter_community_edges,
  37. "is_tournament": nx.algorithms.tournament.is_tournament,
  38. "mutual_weight": nx.algorithms.structuralholes.mutual_weight,
  39. "score_sequence": nx.algorithms.tournament.score_sequence,
  40. "tournament_matrix": nx.algorithms.tournament.tournament_matrix,
  41. }
  42. def __getattr__(self, item):
  43. # Return the original, undecorated NetworkX algorithm
  44. if hasattr(nx, item):
  45. return getattr(nx, item)._orig_func
  46. if item in self.non_toplevel:
  47. return self.non_toplevel[item]._orig_func
  48. raise AttributeError(item)
  49. @staticmethod
  50. def convert_from_nx(graph, weight=None, *, name=None):
  51. return graph
  52. @staticmethod
  53. def convert_to_nx(obj, *, name=None):
  54. return obj
  55. @staticmethod
  56. def on_start_tests(items):
  57. # Verify that items can be xfailed
  58. for item in items:
  59. assert hasattr(item, "add_marker")
  60. dispatcher = LoopbackDispatcher()