misc.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491
  1. """
  2. Miscellaneous Helpers for NetworkX.
  3. These are not imported into the base networkx namespace but
  4. can be accessed, for example, as
  5. >>> import networkx
  6. >>> networkx.utils.make_list_of_ints({1, 2, 3})
  7. [1, 2, 3]
  8. >>> networkx.utils.arbitrary_element({5, 1, 7}) # doctest: +SKIP
  9. 1
  10. """
  11. import sys
  12. import uuid
  13. import warnings
  14. from collections import defaultdict, deque
  15. from collections.abc import Iterable, Iterator, Sized
  16. from itertools import chain, tee
  17. import networkx as nx
# Names exported by a star-import of this module.
__all__ = [
    "flatten",
    "make_list_of_ints",
    "dict_to_numpy_array",
    "arbitrary_element",
    "pairwise",
    "groups",
    "create_random_state",
    "create_py_random_state",
    "PythonRandomInterface",
    "nodes_equal",
    "edges_equal",
    "graphs_equal",
]
  32. # some cookbook stuff
  33. # used in deciding whether something is a bunch of nodes, edges, etc.
  34. # see G.add_nodes and others in Graph Class in networkx/base.py
  35. def flatten(obj, result=None):
  36. """Return flattened version of (possibly nested) iterable object."""
  37. if not isinstance(obj, (Iterable, Sized)) or isinstance(obj, str):
  38. return obj
  39. if result is None:
  40. result = []
  41. for item in obj:
  42. if not isinstance(item, (Iterable, Sized)) or isinstance(item, str):
  43. result.append(item)
  44. else:
  45. flatten(item, result)
  46. return tuple(result)
  47. def make_list_of_ints(sequence):
  48. """Return list of ints from sequence of integral numbers.
  49. All elements of the sequence must satisfy int(element) == element
  50. or a ValueError is raised. Sequence is iterated through once.
  51. If sequence is a list, the non-int values are replaced with ints.
  52. So, no new list is created
  53. """
  54. if not isinstance(sequence, list):
  55. result = []
  56. for i in sequence:
  57. errmsg = f"sequence is not all integers: {i}"
  58. try:
  59. ii = int(i)
  60. except ValueError:
  61. raise nx.NetworkXError(errmsg) from None
  62. if ii != i:
  63. raise nx.NetworkXError(errmsg)
  64. result.append(ii)
  65. return result
  66. # original sequence is a list... in-place conversion to ints
  67. for indx, i in enumerate(sequence):
  68. errmsg = f"sequence is not all integers: {i}"
  69. if isinstance(i, int):
  70. continue
  71. try:
  72. ii = int(i)
  73. except ValueError:
  74. raise nx.NetworkXError(errmsg) from None
  75. if ii != i:
  76. raise nx.NetworkXError(errmsg)
  77. sequence[indx] = ii
  78. return sequence
  79. def dict_to_numpy_array(d, mapping=None):
  80. """Convert a dictionary of dictionaries to a numpy array
  81. with optional mapping."""
  82. try:
  83. return _dict_to_numpy_array2(d, mapping)
  84. except (AttributeError, TypeError):
  85. # AttributeError is when no mapping was provided and v.keys() fails.
  86. # TypeError is when a mapping was provided and d[k1][k2] fails.
  87. return _dict_to_numpy_array1(d, mapping)
  88. def _dict_to_numpy_array2(d, mapping=None):
  89. """Convert a dictionary of dictionaries to a 2d numpy array
  90. with optional mapping.
  91. """
  92. import numpy as np
  93. if mapping is None:
  94. s = set(d.keys())
  95. for k, v in d.items():
  96. s.update(v.keys())
  97. mapping = dict(zip(s, range(len(s))))
  98. n = len(mapping)
  99. a = np.zeros((n, n))
  100. for k1, i in mapping.items():
  101. for k2, j in mapping.items():
  102. try:
  103. a[i, j] = d[k1][k2]
  104. except KeyError:
  105. pass
  106. return a
  107. def _dict_to_numpy_array1(d, mapping=None):
  108. """Convert a dictionary of numbers to a 1d numpy array with optional mapping."""
  109. import numpy as np
  110. if mapping is None:
  111. s = set(d.keys())
  112. mapping = dict(zip(s, range(len(s))))
  113. n = len(mapping)
  114. a = np.zeros(n)
  115. for k1, i in mapping.items():
  116. i = mapping[k1]
  117. a[i] = d[k1]
  118. return a
  119. def arbitrary_element(iterable):
  120. """Returns an arbitrary element of `iterable` without removing it.
  121. This is most useful for "peeking" at an arbitrary element of a set,
  122. but can be used for any list, dictionary, etc., as well.
  123. Parameters
  124. ----------
  125. iterable : `abc.collections.Iterable` instance
  126. Any object that implements ``__iter__``, e.g. set, dict, list, tuple,
  127. etc.
  128. Returns
  129. -------
  130. The object that results from ``next(iter(iterable))``
  131. Raises
  132. ------
  133. ValueError
  134. If `iterable` is an iterator (because the current implementation of
  135. this function would consume an element from the iterator).
  136. Examples
  137. --------
  138. Arbitrary elements from common Iterable objects:
  139. >>> nx.utils.arbitrary_element([1, 2, 3]) # list
  140. 1
  141. >>> nx.utils.arbitrary_element((1, 2, 3)) # tuple
  142. 1
  143. >>> nx.utils.arbitrary_element({1, 2, 3}) # set
  144. 1
  145. >>> d = {k: v for k, v in zip([1, 2, 3], [3, 2, 1])}
  146. >>> nx.utils.arbitrary_element(d) # dict_keys
  147. 1
  148. >>> nx.utils.arbitrary_element(d.values()) # dict values
  149. 3
  150. `str` is also an Iterable:
  151. >>> nx.utils.arbitrary_element("hello")
  152. 'h'
  153. :exc:`ValueError` is raised if `iterable` is an iterator:
  154. >>> iterator = iter([1, 2, 3]) # Iterator, *not* Iterable
  155. >>> nx.utils.arbitrary_element(iterator)
  156. Traceback (most recent call last):
  157. ...
  158. ValueError: cannot return an arbitrary item from an iterator
  159. Notes
  160. -----
  161. This function does not return a *random* element. If `iterable` is
  162. ordered, sequential calls will return the same value::
  163. >>> l = [1, 2, 3]
  164. >>> nx.utils.arbitrary_element(l)
  165. 1
  166. >>> nx.utils.arbitrary_element(l)
  167. 1
  168. """
  169. if isinstance(iterable, Iterator):
  170. raise ValueError("cannot return an arbitrary item from an iterator")
  171. # Another possible implementation is ``for x in iterable: return x``.
  172. return next(iter(iterable))
  173. # Recipe from the itertools documentation.
  174. def pairwise(iterable, cyclic=False):
  175. "s -> (s0, s1), (s1, s2), (s2, s3), ..."
  176. a, b = tee(iterable)
  177. first = next(b, None)
  178. if cyclic is True:
  179. return zip(a, chain(b, (first,)))
  180. return zip(a, b)
  181. def groups(many_to_one):
  182. """Converts a many-to-one mapping into a one-to-many mapping.
  183. `many_to_one` must be a dictionary whose keys and values are all
  184. :term:`hashable`.
  185. The return value is a dictionary mapping values from `many_to_one`
  186. to sets of keys from `many_to_one` that have that value.
  187. Examples
  188. --------
  189. >>> from networkx.utils import groups
  190. >>> many_to_one = {"a": 1, "b": 1, "c": 2, "d": 3, "e": 3}
  191. >>> groups(many_to_one) # doctest: +SKIP
  192. {1: {'a', 'b'}, 2: {'c'}, 3: {'e', 'd'}}
  193. """
  194. one_to_many = defaultdict(set)
  195. for v, k in many_to_one.items():
  196. one_to_many[k].add(v)
  197. return dict(one_to_many)
  198. def create_random_state(random_state=None):
  199. """Returns a numpy.random.RandomState or numpy.random.Generator instance
  200. depending on input.
  201. Parameters
  202. ----------
  203. random_state : int or NumPy RandomState or Generator instance, optional (default=None)
  204. If int, return a numpy.random.RandomState instance set with seed=int.
  205. if `numpy.random.RandomState` instance, return it.
  206. if `numpy.random.Generator` instance, return it.
  207. if None or numpy.random, return the global random number generator used
  208. by numpy.random.
  209. """
  210. import numpy as np
  211. if random_state is None or random_state is np.random:
  212. return np.random.mtrand._rand
  213. if isinstance(random_state, np.random.RandomState):
  214. return random_state
  215. if isinstance(random_state, int):
  216. return np.random.RandomState(random_state)
  217. if isinstance(random_state, np.random.Generator):
  218. return random_state
  219. msg = (
  220. f"{random_state} cannot be used to create a numpy.random.RandomState or\n"
  221. "numpy.random.Generator instance"
  222. )
  223. raise ValueError(msg)
  224. class PythonRandomInterface:
  225. def __init__(self, rng=None):
  226. try:
  227. import numpy as np
  228. except ImportError:
  229. msg = "numpy not found, only random.random available."
  230. warnings.warn(msg, ImportWarning)
  231. if rng is None:
  232. self._rng = np.random.mtrand._rand
  233. else:
  234. self._rng = rng
  235. def random(self):
  236. return self._rng.random()
  237. def uniform(self, a, b):
  238. return a + (b - a) * self._rng.random()
  239. def randrange(self, a, b=None):
  240. import numpy as np
  241. if isinstance(self._rng, np.random.Generator):
  242. return self._rng.integers(a, b)
  243. return self._rng.randint(a, b)
  244. # NOTE: the numpy implementations of `choice` don't support strings, so
  245. # this cannot be replaced with self._rng.choice
  246. def choice(self, seq):
  247. import numpy as np
  248. if isinstance(self._rng, np.random.Generator):
  249. idx = self._rng.integers(0, len(seq))
  250. else:
  251. idx = self._rng.randint(0, len(seq))
  252. return seq[idx]
  253. def gauss(self, mu, sigma):
  254. return self._rng.normal(mu, sigma)
  255. def shuffle(self, seq):
  256. return self._rng.shuffle(seq)
  257. # Some methods don't match API for numpy RandomState.
  258. # Commented out versions are not used by NetworkX
  259. def sample(self, seq, k):
  260. return self._rng.choice(list(seq), size=(k,), replace=False)
  261. def randint(self, a, b):
  262. import numpy as np
  263. if isinstance(self._rng, np.random.Generator):
  264. return self._rng.integers(a, b + 1)
  265. return self._rng.randint(a, b + 1)
  266. # exponential as expovariate with 1/argument,
  267. def expovariate(self, scale):
  268. return self._rng.exponential(1 / scale)
  269. # pareto as paretovariate with 1/argument,
  270. def paretovariate(self, shape):
  271. return self._rng.pareto(shape)
  272. # weibull as weibullvariate multiplied by beta,
  273. # def weibullvariate(self, alpha, beta):
  274. # return self._rng.weibull(alpha) * beta
  275. #
  276. # def triangular(self, low, high, mode):
  277. # return self._rng.triangular(low, mode, high)
  278. #
  279. # def choices(self, seq, weights=None, cum_weights=None, k=1):
  280. # return self._rng.choice(seq
  281. def create_py_random_state(random_state=None):
  282. """Returns a random.Random instance depending on input.
  283. Parameters
  284. ----------
  285. random_state : int or random number generator or None (default=None)
  286. If int, return a random.Random instance set with seed=int.
  287. if random.Random instance, return it.
  288. if None or the `random` package, return the global random number
  289. generator used by `random`.
  290. if np.random package, return the global numpy random number
  291. generator wrapped in a PythonRandomInterface class.
  292. if np.random.RandomState or np.random.Generator instance, return it
  293. wrapped in PythonRandomInterface
  294. if a PythonRandomInterface instance, return it
  295. """
  296. import random
  297. try:
  298. import numpy as np
  299. if random_state is np.random:
  300. return PythonRandomInterface(np.random.mtrand._rand)
  301. if isinstance(random_state, (np.random.RandomState, np.random.Generator)):
  302. return PythonRandomInterface(random_state)
  303. if isinstance(random_state, PythonRandomInterface):
  304. return random_state
  305. except ImportError:
  306. pass
  307. if random_state is None or random_state is random:
  308. return random._inst
  309. if isinstance(random_state, random.Random):
  310. return random_state
  311. if isinstance(random_state, int):
  312. return random.Random(random_state)
  313. msg = f"{random_state} cannot be used to generate a random.Random instance"
  314. raise ValueError(msg)
  315. def nodes_equal(nodes1, nodes2):
  316. """Check if nodes are equal.
  317. Equality here means equal as Python objects.
  318. Node data must match if included.
  319. The order of nodes is not relevant.
  320. Parameters
  321. ----------
  322. nodes1, nodes2 : iterables of nodes, or (node, datadict) tuples
  323. Returns
  324. -------
  325. bool
  326. True if nodes are equal, False otherwise.
  327. """
  328. nlist1 = list(nodes1)
  329. nlist2 = list(nodes2)
  330. try:
  331. d1 = dict(nlist1)
  332. d2 = dict(nlist2)
  333. except (ValueError, TypeError):
  334. d1 = dict.fromkeys(nlist1)
  335. d2 = dict.fromkeys(nlist2)
  336. return d1 == d2
  337. def edges_equal(edges1, edges2):
  338. """Check if edges are equal.
  339. Equality here means equal as Python objects.
  340. Edge data must match if included.
  341. The order of the edges is not relevant.
  342. Parameters
  343. ----------
  344. edges1, edges2 : iterables of with u, v nodes as
  345. edge tuples (u, v), or
  346. edge tuples with data dicts (u, v, d), or
  347. edge tuples with keys and data dicts (u, v, k, d)
  348. Returns
  349. -------
  350. bool
  351. True if edges are equal, False otherwise.
  352. """
  353. from collections import defaultdict
  354. d1 = defaultdict(dict)
  355. d2 = defaultdict(dict)
  356. c1 = 0
  357. for c1, e in enumerate(edges1):
  358. u, v = e[0], e[1]
  359. data = [e[2:]]
  360. if v in d1[u]:
  361. data = d1[u][v] + data
  362. d1[u][v] = data
  363. d1[v][u] = data
  364. c2 = 0
  365. for c2, e in enumerate(edges2):
  366. u, v = e[0], e[1]
  367. data = [e[2:]]
  368. if v in d2[u]:
  369. data = d2[u][v] + data
  370. d2[u][v] = data
  371. d2[v][u] = data
  372. if c1 != c2:
  373. return False
  374. # can check one direction because lengths are the same.
  375. for n, nbrdict in d1.items():
  376. for nbr, datalist in nbrdict.items():
  377. if n not in d2:
  378. return False
  379. if nbr not in d2[n]:
  380. return False
  381. d2datalist = d2[n][nbr]
  382. for data in datalist:
  383. if datalist.count(data) != d2datalist.count(data):
  384. return False
  385. return True
  386. def graphs_equal(graph1, graph2):
  387. """Check if graphs are equal.
  388. Equality here means equal as Python objects (not isomorphism).
  389. Node, edge and graph data must match.
  390. Parameters
  391. ----------
  392. graph1, graph2 : graph
  393. Returns
  394. -------
  395. bool
  396. True if graphs are equal, False otherwise.
  397. """
  398. return (
  399. graph1.adj == graph2.adj
  400. and graph1.nodes == graph2.nodes
  401. and graph1.graph == graph2.graph
  402. )