test_layout.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
"""Unit tests for layout functions."""
import pytest
import networkx as nx
# Skip this whole test module when numpy is missing; otherwise bind it as
# ``np`` because the tests below use numpy directly.
np = pytest.importorskip("numpy")
# scipy is not referenced by name in this file, but the tests build scipy
# sparse arrays (``nx.to_scipy_sparse_array``) and exercise the sparse
# layout helpers, so skip the module when it is unavailable.
pytest.importorskip("scipy")
  6. class TestLayout:
  7. @classmethod
  8. def setup_class(cls):
  9. cls.Gi = nx.grid_2d_graph(5, 5)
  10. cls.Gs = nx.Graph()
  11. nx.add_path(cls.Gs, "abcdef")
  12. cls.bigG = nx.grid_2d_graph(25, 25) # > 500 nodes for sparse
  13. def test_spring_fixed_without_pos(self):
  14. G = nx.path_graph(4)
  15. pytest.raises(ValueError, nx.spring_layout, G, fixed=[0])
  16. pos = {0: (1, 1), 2: (0, 0)}
  17. pytest.raises(ValueError, nx.spring_layout, G, fixed=[0, 1], pos=pos)
  18. nx.spring_layout(G, fixed=[0, 2], pos=pos) # No ValueError
  19. def test_spring_init_pos(self):
  20. # Tests GH #2448
  21. import math
  22. G = nx.Graph()
  23. G.add_edges_from([(0, 1), (1, 2), (2, 0), (2, 3)])
  24. init_pos = {0: (0.0, 0.0)}
  25. fixed_pos = [0]
  26. pos = nx.fruchterman_reingold_layout(G, pos=init_pos, fixed=fixed_pos)
  27. has_nan = any(math.isnan(c) for coords in pos.values() for c in coords)
  28. assert not has_nan, "values should not be nan"
  29. def test_smoke_empty_graph(self):
  30. G = []
  31. nx.random_layout(G)
  32. nx.circular_layout(G)
  33. nx.planar_layout(G)
  34. nx.spring_layout(G)
  35. nx.fruchterman_reingold_layout(G)
  36. nx.spectral_layout(G)
  37. nx.shell_layout(G)
  38. nx.bipartite_layout(G, G)
  39. nx.spiral_layout(G)
  40. nx.multipartite_layout(G)
  41. nx.kamada_kawai_layout(G)
  42. def test_smoke_int(self):
  43. G = self.Gi
  44. nx.random_layout(G)
  45. nx.circular_layout(G)
  46. nx.planar_layout(G)
  47. nx.spring_layout(G)
  48. nx.fruchterman_reingold_layout(G)
  49. nx.fruchterman_reingold_layout(self.bigG)
  50. nx.spectral_layout(G)
  51. nx.spectral_layout(G.to_directed())
  52. nx.spectral_layout(self.bigG)
  53. nx.spectral_layout(self.bigG.to_directed())
  54. nx.shell_layout(G)
  55. nx.spiral_layout(G)
  56. nx.kamada_kawai_layout(G)
  57. nx.kamada_kawai_layout(G, dim=1)
  58. nx.kamada_kawai_layout(G, dim=3)
  59. nx.arf_layout(G)
  60. def test_smoke_string(self):
  61. G = self.Gs
  62. nx.random_layout(G)
  63. nx.circular_layout(G)
  64. nx.planar_layout(G)
  65. nx.spring_layout(G)
  66. nx.fruchterman_reingold_layout(G)
  67. nx.spectral_layout(G)
  68. nx.shell_layout(G)
  69. nx.spiral_layout(G)
  70. nx.kamada_kawai_layout(G)
  71. nx.kamada_kawai_layout(G, dim=1)
  72. nx.kamada_kawai_layout(G, dim=3)
  73. nx.arf_layout(G)
  74. def check_scale_and_center(self, pos, scale, center):
  75. center = np.array(center)
  76. low = center - scale
  77. hi = center + scale
  78. vpos = np.array(list(pos.values()))
  79. length = vpos.max(0) - vpos.min(0)
  80. assert (length <= 2 * scale).all()
  81. assert (vpos >= low).all()
  82. assert (vpos <= hi).all()
  83. def test_scale_and_center_arg(self):
  84. sc = self.check_scale_and_center
  85. c = (4, 5)
  86. G = nx.complete_graph(9)
  87. G.add_node(9)
  88. sc(nx.random_layout(G, center=c), scale=0.5, center=(4.5, 5.5))
  89. # rest can have 2*scale length: [-scale, scale]
  90. sc(nx.spring_layout(G, scale=2, center=c), scale=2, center=c)
  91. sc(nx.spectral_layout(G, scale=2, center=c), scale=2, center=c)
  92. sc(nx.circular_layout(G, scale=2, center=c), scale=2, center=c)
  93. sc(nx.shell_layout(G, scale=2, center=c), scale=2, center=c)
  94. sc(nx.spiral_layout(G, scale=2, center=c), scale=2, center=c)
  95. sc(nx.kamada_kawai_layout(G, scale=2, center=c), scale=2, center=c)
  96. c = (2, 3, 5)
  97. sc(nx.kamada_kawai_layout(G, dim=3, scale=2, center=c), scale=2, center=c)
  98. def test_planar_layout_non_planar_input(self):
  99. G = nx.complete_graph(9)
  100. pytest.raises(nx.NetworkXException, nx.planar_layout, G)
  101. def test_smoke_planar_layout_embedding_input(self):
  102. embedding = nx.PlanarEmbedding()
  103. embedding.set_data({0: [1, 2], 1: [0, 2], 2: [0, 1]})
  104. nx.planar_layout(embedding)
  105. def test_default_scale_and_center(self):
  106. sc = self.check_scale_and_center
  107. c = (0, 0)
  108. G = nx.complete_graph(9)
  109. G.add_node(9)
  110. sc(nx.random_layout(G), scale=0.5, center=(0.5, 0.5))
  111. sc(nx.spring_layout(G), scale=1, center=c)
  112. sc(nx.spectral_layout(G), scale=1, center=c)
  113. sc(nx.circular_layout(G), scale=1, center=c)
  114. sc(nx.shell_layout(G), scale=1, center=c)
  115. sc(nx.spiral_layout(G), scale=1, center=c)
  116. sc(nx.kamada_kawai_layout(G), scale=1, center=c)
  117. c = (0, 0, 0)
  118. sc(nx.kamada_kawai_layout(G, dim=3), scale=1, center=c)
  119. def test_circular_planar_and_shell_dim_error(self):
  120. G = nx.path_graph(4)
  121. pytest.raises(ValueError, nx.circular_layout, G, dim=1)
  122. pytest.raises(ValueError, nx.shell_layout, G, dim=1)
  123. pytest.raises(ValueError, nx.shell_layout, G, dim=3)
  124. pytest.raises(ValueError, nx.planar_layout, G, dim=1)
  125. pytest.raises(ValueError, nx.planar_layout, G, dim=3)
  126. def test_adjacency_interface_numpy(self):
  127. A = nx.to_numpy_array(self.Gs)
  128. pos = nx.drawing.layout._fruchterman_reingold(A)
  129. assert pos.shape == (6, 2)
  130. pos = nx.drawing.layout._fruchterman_reingold(A, dim=3)
  131. assert pos.shape == (6, 3)
  132. pos = nx.drawing.layout._sparse_fruchterman_reingold(A)
  133. assert pos.shape == (6, 2)
  134. def test_adjacency_interface_scipy(self):
  135. A = nx.to_scipy_sparse_array(self.Gs, dtype="d")
  136. pos = nx.drawing.layout._sparse_fruchterman_reingold(A)
  137. assert pos.shape == (6, 2)
  138. pos = nx.drawing.layout._sparse_spectral(A)
  139. assert pos.shape == (6, 2)
  140. pos = nx.drawing.layout._sparse_fruchterman_reingold(A, dim=3)
  141. assert pos.shape == (6, 3)
  142. def test_single_nodes(self):
  143. G = nx.path_graph(1)
  144. vpos = nx.shell_layout(G)
  145. assert not vpos[0].any()
  146. G = nx.path_graph(4)
  147. vpos = nx.shell_layout(G, [[0], [1, 2], [3]])
  148. assert not vpos[0].any()
  149. assert vpos[3].any() # ensure node 3 not at origin (#3188)
  150. assert np.linalg.norm(vpos[3]) <= 1 # ensure node 3 fits (#3753)
  151. vpos = nx.shell_layout(G, [[0], [1, 2], [3]], rotate=0)
  152. assert np.linalg.norm(vpos[3]) <= 1 # ensure node 3 fits (#3753)
  153. def test_smoke_initial_pos_fruchterman_reingold(self):
  154. pos = nx.circular_layout(self.Gi)
  155. npos = nx.fruchterman_reingold_layout(self.Gi, pos=pos)
  156. def test_smoke_initial_pos_arf(self):
  157. pos = nx.circular_layout(self.Gi)
  158. npos = nx.arf_layout(self.Gi, pos=pos)
  159. def test_fixed_node_fruchterman_reingold(self):
  160. # Dense version (numpy based)
  161. pos = nx.circular_layout(self.Gi)
  162. npos = nx.spring_layout(self.Gi, pos=pos, fixed=[(0, 0)])
  163. assert tuple(pos[(0, 0)]) == tuple(npos[(0, 0)])
  164. # Sparse version (scipy based)
  165. pos = nx.circular_layout(self.bigG)
  166. npos = nx.spring_layout(self.bigG, pos=pos, fixed=[(0, 0)])
  167. for axis in range(2):
  168. assert pos[(0, 0)][axis] == pytest.approx(npos[(0, 0)][axis], abs=1e-7)
  169. def test_center_parameter(self):
  170. G = nx.path_graph(1)
  171. nx.random_layout(G, center=(1, 1))
  172. vpos = nx.circular_layout(G, center=(1, 1))
  173. assert tuple(vpos[0]) == (1, 1)
  174. vpos = nx.planar_layout(G, center=(1, 1))
  175. assert tuple(vpos[0]) == (1, 1)
  176. vpos = nx.spring_layout(G, center=(1, 1))
  177. assert tuple(vpos[0]) == (1, 1)
  178. vpos = nx.fruchterman_reingold_layout(G, center=(1, 1))
  179. assert tuple(vpos[0]) == (1, 1)
  180. vpos = nx.spectral_layout(G, center=(1, 1))
  181. assert tuple(vpos[0]) == (1, 1)
  182. vpos = nx.shell_layout(G, center=(1, 1))
  183. assert tuple(vpos[0]) == (1, 1)
  184. vpos = nx.spiral_layout(G, center=(1, 1))
  185. assert tuple(vpos[0]) == (1, 1)
  186. def test_center_wrong_dimensions(self):
  187. G = nx.path_graph(1)
  188. assert id(nx.spring_layout) == id(nx.fruchterman_reingold_layout)
  189. pytest.raises(ValueError, nx.random_layout, G, center=(1, 1, 1))
  190. pytest.raises(ValueError, nx.circular_layout, G, center=(1, 1, 1))
  191. pytest.raises(ValueError, nx.planar_layout, G, center=(1, 1, 1))
  192. pytest.raises(ValueError, nx.spring_layout, G, center=(1, 1, 1))
  193. pytest.raises(ValueError, nx.spring_layout, G, dim=3, center=(1, 1))
  194. pytest.raises(ValueError, nx.spectral_layout, G, center=(1, 1, 1))
  195. pytest.raises(ValueError, nx.spectral_layout, G, dim=3, center=(1, 1))
  196. pytest.raises(ValueError, nx.shell_layout, G, center=(1, 1, 1))
  197. pytest.raises(ValueError, nx.spiral_layout, G, center=(1, 1, 1))
  198. pytest.raises(ValueError, nx.kamada_kawai_layout, G, center=(1, 1, 1))
  199. def test_empty_graph(self):
  200. G = nx.empty_graph()
  201. vpos = nx.random_layout(G, center=(1, 1))
  202. assert vpos == {}
  203. vpos = nx.circular_layout(G, center=(1, 1))
  204. assert vpos == {}
  205. vpos = nx.planar_layout(G, center=(1, 1))
  206. assert vpos == {}
  207. vpos = nx.bipartite_layout(G, G)
  208. assert vpos == {}
  209. vpos = nx.spring_layout(G, center=(1, 1))
  210. assert vpos == {}
  211. vpos = nx.fruchterman_reingold_layout(G, center=(1, 1))
  212. assert vpos == {}
  213. vpos = nx.spectral_layout(G, center=(1, 1))
  214. assert vpos == {}
  215. vpos = nx.shell_layout(G, center=(1, 1))
  216. assert vpos == {}
  217. vpos = nx.spiral_layout(G, center=(1, 1))
  218. assert vpos == {}
  219. vpos = nx.multipartite_layout(G, center=(1, 1))
  220. assert vpos == {}
  221. vpos = nx.kamada_kawai_layout(G, center=(1, 1))
  222. assert vpos == {}
  223. vpos = nx.arf_layout(G)
  224. assert vpos == {}
  225. def test_bipartite_layout(self):
  226. G = nx.complete_bipartite_graph(3, 5)
  227. top, bottom = nx.bipartite.sets(G)
  228. vpos = nx.bipartite_layout(G, top)
  229. assert len(vpos) == len(G)
  230. top_x = vpos[list(top)[0]][0]
  231. bottom_x = vpos[list(bottom)[0]][0]
  232. for node in top:
  233. assert vpos[node][0] == top_x
  234. for node in bottom:
  235. assert vpos[node][0] == bottom_x
  236. vpos = nx.bipartite_layout(
  237. G, top, align="horizontal", center=(2, 2), scale=2, aspect_ratio=1
  238. )
  239. assert len(vpos) == len(G)
  240. top_y = vpos[list(top)[0]][1]
  241. bottom_y = vpos[list(bottom)[0]][1]
  242. for node in top:
  243. assert vpos[node][1] == top_y
  244. for node in bottom:
  245. assert vpos[node][1] == bottom_y
  246. pytest.raises(ValueError, nx.bipartite_layout, G, top, align="foo")
  247. def test_multipartite_layout(self):
  248. sizes = (0, 5, 7, 2, 8)
  249. G = nx.complete_multipartite_graph(*sizes)
  250. vpos = nx.multipartite_layout(G)
  251. assert len(vpos) == len(G)
  252. start = 0
  253. for n in sizes:
  254. end = start + n
  255. assert all(vpos[start][0] == vpos[i][0] for i in range(start + 1, end))
  256. start += n
  257. vpos = nx.multipartite_layout(G, align="horizontal", scale=2, center=(2, 2))
  258. assert len(vpos) == len(G)
  259. start = 0
  260. for n in sizes:
  261. end = start + n
  262. assert all(vpos[start][1] == vpos[i][1] for i in range(start + 1, end))
  263. start += n
  264. pytest.raises(ValueError, nx.multipartite_layout, G, align="foo")
  265. def test_kamada_kawai_costfn_1d(self):
  266. costfn = nx.drawing.layout._kamada_kawai_costfn
  267. pos = np.array([4.0, 7.0])
  268. invdist = 1 / np.array([[0.1, 2.0], [2.0, 0.3]])
  269. cost, grad = costfn(pos, np, invdist, meanweight=0, dim=1)
  270. assert cost == pytest.approx(((3 / 2.0 - 1) ** 2), abs=1e-7)
  271. assert grad[0] == pytest.approx((-0.5), abs=1e-7)
  272. assert grad[1] == pytest.approx(0.5, abs=1e-7)
  273. def check_kamada_kawai_costfn(self, pos, invdist, meanwt, dim):
  274. costfn = nx.drawing.layout._kamada_kawai_costfn
  275. cost, grad = costfn(pos.ravel(), np, invdist, meanweight=meanwt, dim=dim)
  276. expected_cost = 0.5 * meanwt * np.sum(np.sum(pos, axis=0) ** 2)
  277. for i in range(pos.shape[0]):
  278. for j in range(i + 1, pos.shape[0]):
  279. diff = np.linalg.norm(pos[i] - pos[j])
  280. expected_cost += (diff * invdist[i][j] - 1.0) ** 2
  281. assert cost == pytest.approx(expected_cost, abs=1e-7)
  282. dx = 1e-4
  283. for nd in range(pos.shape[0]):
  284. for dm in range(pos.shape[1]):
  285. idx = nd * pos.shape[1] + dm
  286. ps = pos.flatten()
  287. ps[idx] += dx
  288. cplus = costfn(ps, np, invdist, meanweight=meanwt, dim=pos.shape[1])[0]
  289. ps[idx] -= 2 * dx
  290. cminus = costfn(ps, np, invdist, meanweight=meanwt, dim=pos.shape[1])[0]
  291. assert grad[idx] == pytest.approx((cplus - cminus) / (2 * dx), abs=1e-5)
  292. def test_kamada_kawai_costfn(self):
  293. invdist = 1 / np.array([[0.1, 2.1, 1.7], [2.1, 0.2, 0.6], [1.7, 0.6, 0.3]])
  294. meanwt = 0.3
  295. # 2d
  296. pos = np.array([[1.3, -3.2], [2.7, -0.3], [5.1, 2.5]])
  297. self.check_kamada_kawai_costfn(pos, invdist, meanwt, 2)
  298. # 3d
  299. pos = np.array([[0.9, 8.6, -8.7], [-10, -0.5, -7.1], [9.1, -8.1, 1.6]])
  300. self.check_kamada_kawai_costfn(pos, invdist, meanwt, 3)
  301. def test_spiral_layout(self):
  302. G = self.Gs
  303. # a lower value of resolution should result in a more compact layout
  304. # intuitively, the total distance from the start and end nodes
  305. # via each node in between (transiting through each) will be less,
  306. # assuming rescaling does not occur on the computed node positions
  307. pos_standard = np.array(list(nx.spiral_layout(G, resolution=0.35).values()))
  308. pos_tighter = np.array(list(nx.spiral_layout(G, resolution=0.34).values()))
  309. distances = np.linalg.norm(pos_standard[:-1] - pos_standard[1:], axis=1)
  310. distances_tighter = np.linalg.norm(pos_tighter[:-1] - pos_tighter[1:], axis=1)
  311. assert sum(distances) > sum(distances_tighter)
  312. # return near-equidistant points after the first value if set to true
  313. pos_equidistant = np.array(list(nx.spiral_layout(G, equidistant=True).values()))
  314. distances_equidistant = np.linalg.norm(
  315. pos_equidistant[:-1] - pos_equidistant[1:], axis=1
  316. )
  317. assert np.allclose(
  318. distances_equidistant[1:], distances_equidistant[-1], atol=0.01
  319. )
  320. def test_spiral_layout_equidistant(self):
  321. G = nx.path_graph(10)
  322. pos = nx.spiral_layout(G, equidistant=True)
  323. # Extract individual node positions as an array
  324. p = np.array(list(pos.values()))
  325. # Elementwise-distance between node positions
  326. dist = np.linalg.norm(p[1:] - p[:-1], axis=1)
  327. assert np.allclose(np.diff(dist), 0, atol=1e-3)
  328. def test_rescale_layout_dict(self):
  329. G = nx.empty_graph()
  330. vpos = nx.random_layout(G, center=(1, 1))
  331. assert nx.rescale_layout_dict(vpos) == {}
  332. G = nx.empty_graph(2)
  333. vpos = {0: (0.0, 0.0), 1: (1.0, 1.0)}
  334. s_vpos = nx.rescale_layout_dict(vpos)
  335. assert np.linalg.norm([sum(x) for x in zip(*s_vpos.values())]) < 1e-6
  336. G = nx.empty_graph(3)
  337. vpos = {0: (0, 0), 1: (1, 1), 2: (0.5, 0.5)}
  338. s_vpos = nx.rescale_layout_dict(vpos)
  339. expectation = {
  340. 0: np.array((-1, -1)),
  341. 1: np.array((1, 1)),
  342. 2: np.array((0, 0)),
  343. }
  344. for k, v in expectation.items():
  345. assert (s_vpos[k] == v).all()
  346. s_vpos = nx.rescale_layout_dict(vpos, scale=2)
  347. expectation = {
  348. 0: np.array((-2, -2)),
  349. 1: np.array((2, 2)),
  350. 2: np.array((0, 0)),
  351. }
  352. for k, v in expectation.items():
  353. assert (s_vpos[k] == v).all()
  354. def test_arf_layout_partial_input_test(self):
  355. """
  356. Checks whether partial pos input still returns a proper position.
  357. """
  358. G = self.Gs
  359. node = nx.utils.arbitrary_element(G)
  360. pos = nx.circular_layout(G)
  361. del pos[node]
  362. pos = nx.arf_layout(G, pos=pos)
  363. assert len(pos) == len(G)
  364. def test_arf_layout_negative_a_check(self):
  365. """
  366. Checks input parameters correctly raises errors. For example, `a` should be larger than 1
  367. """
  368. G = self.Gs
  369. pytest.raises(ValueError, nx.arf_layout, G=G, a=-1)
  370. def test_multipartite_layout_nonnumeric_partition_labels():
  371. """See gh-5123."""
  372. G = nx.Graph()
  373. G.add_node(0, subset="s0")
  374. G.add_node(1, subset="s0")
  375. G.add_node(2, subset="s1")
  376. G.add_node(3, subset="s1")
  377. G.add_edges_from([(0, 2), (0, 3), (1, 2)])
  378. pos = nx.multipartite_layout(G)
  379. assert len(pos) == len(G)
  380. def test_multipartite_layout_layer_order():
  381. """Return the layers in sorted order if the layers of the multipartite
  382. graph are sortable. See gh-5691"""
  383. G = nx.Graph()
  384. for node, layer in zip(("a", "b", "c", "d", "e"), (2, 3, 1, 2, 4)):
  385. G.add_node(node, subset=layer)
  386. # Horizontal alignment, therefore y-coord determines layers
  387. pos = nx.multipartite_layout(G, align="horizontal")
  388. # Nodes "a" and "d" are in the same layer
  389. assert pos["a"][-1] == pos["d"][-1]
  390. # positions should be sorted according to layer
  391. assert pos["c"][-1] < pos["a"][-1] < pos["b"][-1] < pos["e"][-1]
  392. # Make sure that multipartite_layout still works when layers are not sortable
  393. G.nodes["a"]["subset"] = "layer_0" # Can't sort mixed strs/ints
  394. pos_nosort = nx.multipartite_layout(G) # smoke test: this should not raise
  395. assert pos_nosort.keys() == pos.keys()