# wiener.py (2.2 KB)
  1. """Functions related to the Wiener index of a graph."""
  2. from itertools import chain
  3. from .components import is_connected, is_strongly_connected
  4. from .shortest_paths import shortest_path_length as spl
  5. __all__ = ["wiener_index"]
  6. #: Rename the :func:`chain.from_iterable` function for the sake of
  7. #: brevity.
  8. chaini = chain.from_iterable
  9. def wiener_index(G, weight=None):
  10. """Returns the Wiener index of the given graph.
  11. The *Wiener index* of a graph is the sum of the shortest-path
  12. distances between each pair of reachable nodes. For pairs of nodes
  13. in undirected graphs, only one orientation of the pair is counted.
  14. Parameters
  15. ----------
  16. G : NetworkX graph
  17. weight : object
  18. The edge attribute to use as distance when computing
  19. shortest-path distances. This is passed directly to the
  20. :func:`networkx.shortest_path_length` function.
  21. Returns
  22. -------
  23. float
  24. The Wiener index of the graph `G`.
  25. Raises
  26. ------
  27. NetworkXError
  28. If the graph `G` is not connected.
  29. Notes
  30. -----
  31. If a pair of nodes is not reachable, the distance is assumed to be
  32. infinity. This means that for graphs that are not
  33. strongly-connected, this function returns ``inf``.
  34. The Wiener index is not usually defined for directed graphs, however
  35. this function uses the natural generalization of the Wiener index to
  36. directed graphs.
  37. Examples
  38. --------
  39. The Wiener index of the (unweighted) complete graph on *n* nodes
  40. equals the number of pairs of the *n* nodes, since each pair of
  41. nodes is at distance one::
  42. >>> n = 10
  43. >>> G = nx.complete_graph(n)
  44. >>> nx.wiener_index(G) == n * (n - 1) / 2
  45. True
  46. Graphs that are not strongly-connected have infinite Wiener index::
  47. >>> G = nx.empty_graph(2)
  48. >>> nx.wiener_index(G)
  49. inf
  50. """
  51. is_directed = G.is_directed()
  52. if (is_directed and not is_strongly_connected(G)) or (
  53. not is_directed and not is_connected(G)
  54. ):
  55. return float("inf")
  56. total = sum(chaini(p.values() for v, p in spl(G, weight=weight)))
  57. # Need to account for double counting pairs of nodes in undirected graphs.
  58. return total if is_directed else total / 2