test_d_separation.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. from itertools import combinations
  2. import pytest
  3. import networkx as nx
  4. def path_graph():
  5. """Return a path graph of length three."""
  6. G = nx.path_graph(3, create_using=nx.DiGraph)
  7. G.graph["name"] = "path"
  8. nx.freeze(G)
  9. return G
  10. def fork_graph():
  11. """Return a three node fork graph."""
  12. G = nx.DiGraph(name="fork")
  13. G.add_edges_from([(0, 1), (0, 2)])
  14. nx.freeze(G)
  15. return G
  16. def collider_graph():
  17. """Return a collider/v-structure graph with three nodes."""
  18. G = nx.DiGraph(name="collider")
  19. G.add_edges_from([(0, 2), (1, 2)])
  20. nx.freeze(G)
  21. return G
  22. def naive_bayes_graph():
  23. """Return a simply Naive Bayes PGM graph."""
  24. G = nx.DiGraph(name="naive_bayes")
  25. G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
  26. nx.freeze(G)
  27. return G
  28. def asia_graph():
  29. """Return the 'Asia' PGM graph."""
  30. G = nx.DiGraph(name="asia")
  31. G.add_edges_from(
  32. [
  33. ("asia", "tuberculosis"),
  34. ("smoking", "cancer"),
  35. ("smoking", "bronchitis"),
  36. ("tuberculosis", "either"),
  37. ("cancer", "either"),
  38. ("either", "xray"),
  39. ("either", "dyspnea"),
  40. ("bronchitis", "dyspnea"),
  41. ]
  42. )
  43. nx.freeze(G)
  44. return G
  45. @pytest.fixture(name="path_graph")
  46. def path_graph_fixture():
  47. return path_graph()
  48. @pytest.fixture(name="fork_graph")
  49. def fork_graph_fixture():
  50. return fork_graph()
  51. @pytest.fixture(name="collider_graph")
  52. def collider_graph_fixture():
  53. return collider_graph()
  54. @pytest.fixture(name="naive_bayes_graph")
  55. def naive_bayes_graph_fixture():
  56. return naive_bayes_graph()
  57. @pytest.fixture(name="asia_graph")
  58. def asia_graph_fixture():
  59. return asia_graph()
  60. @pytest.mark.parametrize(
  61. "graph",
  62. [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()],
  63. )
  64. def test_markov_condition(graph):
  65. """Test that the Markov condition holds for each PGM graph."""
  66. for node in graph.nodes:
  67. parents = set(graph.predecessors(node))
  68. non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents
  69. assert nx.d_separated(graph, {node}, non_descendants, parents)
  70. def test_path_graph_dsep(path_graph):
  71. """Example-based test of d-separation for path_graph."""
  72. assert nx.d_separated(path_graph, {0}, {2}, {1})
  73. assert not nx.d_separated(path_graph, {0}, {2}, {})
  74. def test_fork_graph_dsep(fork_graph):
  75. """Example-based test of d-separation for fork_graph."""
  76. assert nx.d_separated(fork_graph, {1}, {2}, {0})
  77. assert not nx.d_separated(fork_graph, {1}, {2}, {})
  78. def test_collider_graph_dsep(collider_graph):
  79. """Example-based test of d-separation for collider_graph."""
  80. assert nx.d_separated(collider_graph, {0}, {1}, {})
  81. assert not nx.d_separated(collider_graph, {0}, {1}, {2})
  82. def test_naive_bayes_dsep(naive_bayes_graph):
  83. """Example-based test of d-separation for naive_bayes_graph."""
  84. for u, v in combinations(range(1, 5), 2):
  85. assert nx.d_separated(naive_bayes_graph, {u}, {v}, {0})
  86. assert not nx.d_separated(naive_bayes_graph, {u}, {v}, {})
  87. def test_asia_graph_dsep(asia_graph):
  88. """Example-based test of d-separation for asia_graph."""
  89. assert nx.d_separated(
  90. asia_graph, {"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}
  91. )
  92. assert nx.d_separated(
  93. asia_graph, {"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}
  94. )
  95. def test_undirected_graphs_are_not_supported():
  96. """
  97. Test that undirected graphs are not supported.
  98. d-separation and its related algorithms do not apply in
  99. the case of undirected graphs.
  100. """
  101. g = nx.path_graph(3, nx.Graph)
  102. with pytest.raises(nx.NetworkXNotImplemented):
  103. nx.d_separated(g, {0}, {1}, {2})
  104. with pytest.raises(nx.NetworkXNotImplemented):
  105. nx.is_minimal_d_separator(g, {0}, {1}, {2})
  106. with pytest.raises(nx.NetworkXNotImplemented):
  107. nx.minimal_d_separator(g, {0}, {1})
  108. def test_cyclic_graphs_raise_error():
  109. """
  110. Test that cycle graphs should cause erroring.
  111. This is because PGMs assume a directed acyclic graph.
  112. """
  113. g = nx.cycle_graph(3, nx.DiGraph)
  114. with pytest.raises(nx.NetworkXError):
  115. nx.d_separated(g, {0}, {1}, {2})
  116. with pytest.raises(nx.NetworkXError):
  117. nx.minimal_d_separator(g, {0}, {1})
  118. with pytest.raises(nx.NetworkXError):
  119. nx.is_minimal_d_separator(g, {0}, {1}, {2})
  120. def test_invalid_nodes_raise_error(asia_graph):
  121. """
  122. Test that graphs that have invalid nodes passed in raise errors.
  123. """
  124. with pytest.raises(nx.NodeNotFound):
  125. nx.d_separated(asia_graph, {0}, {1}, {2})
  126. with pytest.raises(nx.NodeNotFound):
  127. nx.is_minimal_d_separator(asia_graph, 0, 1, {2})
  128. with pytest.raises(nx.NodeNotFound):
  129. nx.minimal_d_separator(asia_graph, 0, 1)
  130. def test_minimal_d_separator():
  131. # Case 1:
  132. # create a graph A -> B <- C
  133. # B -> D -> E;
  134. # B -> F;
  135. # G -> E;
  136. edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")]
  137. G = nx.DiGraph(edge_list)
  138. assert not nx.d_separated(G, {"B"}, {"E"}, set())
  139. # minimal set of the corresponding graph
  140. # for B and E should be (D,)
  141. Zmin = nx.minimal_d_separator(G, "B", "E")
  142. # the minimal separating set should pass the test for minimality
  143. assert nx.is_minimal_d_separator(G, "B", "E", Zmin)
  144. assert Zmin == {"D"}
  145. # Case 2:
  146. # create a graph A -> B -> C
  147. # B -> D -> C;
  148. edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")]
  149. G = nx.DiGraph(edge_list)
  150. assert not nx.d_separated(G, {"A"}, {"C"}, set())
  151. Zmin = nx.minimal_d_separator(G, "A", "C")
  152. # the minimal separating set should pass the test for minimality
  153. assert nx.is_minimal_d_separator(G, "A", "C", Zmin)
  154. assert Zmin == {"B"}