123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202 |
- from itertools import combinations
- import pytest
- import networkx as nx
- def path_graph():
- """Return a path graph of length three."""
- G = nx.path_graph(3, create_using=nx.DiGraph)
- G.graph["name"] = "path"
- nx.freeze(G)
- return G
- def fork_graph():
- """Return a three node fork graph."""
- G = nx.DiGraph(name="fork")
- G.add_edges_from([(0, 1), (0, 2)])
- nx.freeze(G)
- return G
- def collider_graph():
- """Return a collider/v-structure graph with three nodes."""
- G = nx.DiGraph(name="collider")
- G.add_edges_from([(0, 2), (1, 2)])
- nx.freeze(G)
- return G
- def naive_bayes_graph():
- """Return a simply Naive Bayes PGM graph."""
- G = nx.DiGraph(name="naive_bayes")
- G.add_edges_from([(0, 1), (0, 2), (0, 3), (0, 4)])
- nx.freeze(G)
- return G
- def asia_graph():
- """Return the 'Asia' PGM graph."""
- G = nx.DiGraph(name="asia")
- G.add_edges_from(
- [
- ("asia", "tuberculosis"),
- ("smoking", "cancer"),
- ("smoking", "bronchitis"),
- ("tuberculosis", "either"),
- ("cancer", "either"),
- ("either", "xray"),
- ("either", "dyspnea"),
- ("bronchitis", "dyspnea"),
- ]
- )
- nx.freeze(G)
- return G
- @pytest.fixture(name="path_graph")
- def path_graph_fixture():
- return path_graph()
- @pytest.fixture(name="fork_graph")
- def fork_graph_fixture():
- return fork_graph()
- @pytest.fixture(name="collider_graph")
- def collider_graph_fixture():
- return collider_graph()
- @pytest.fixture(name="naive_bayes_graph")
- def naive_bayes_graph_fixture():
- return naive_bayes_graph()
- @pytest.fixture(name="asia_graph")
- def asia_graph_fixture():
- return asia_graph()
- @pytest.mark.parametrize(
- "graph",
- [path_graph(), fork_graph(), collider_graph(), naive_bayes_graph(), asia_graph()],
- )
- def test_markov_condition(graph):
- """Test that the Markov condition holds for each PGM graph."""
- for node in graph.nodes:
- parents = set(graph.predecessors(node))
- non_descendants = graph.nodes - nx.descendants(graph, node) - {node} - parents
- assert nx.d_separated(graph, {node}, non_descendants, parents)
- def test_path_graph_dsep(path_graph):
- """Example-based test of d-separation for path_graph."""
- assert nx.d_separated(path_graph, {0}, {2}, {1})
- assert not nx.d_separated(path_graph, {0}, {2}, {})
- def test_fork_graph_dsep(fork_graph):
- """Example-based test of d-separation for fork_graph."""
- assert nx.d_separated(fork_graph, {1}, {2}, {0})
- assert not nx.d_separated(fork_graph, {1}, {2}, {})
- def test_collider_graph_dsep(collider_graph):
- """Example-based test of d-separation for collider_graph."""
- assert nx.d_separated(collider_graph, {0}, {1}, {})
- assert not nx.d_separated(collider_graph, {0}, {1}, {2})
- def test_naive_bayes_dsep(naive_bayes_graph):
- """Example-based test of d-separation for naive_bayes_graph."""
- for u, v in combinations(range(1, 5), 2):
- assert nx.d_separated(naive_bayes_graph, {u}, {v}, {0})
- assert not nx.d_separated(naive_bayes_graph, {u}, {v}, {})
- def test_asia_graph_dsep(asia_graph):
- """Example-based test of d-separation for asia_graph."""
- assert nx.d_separated(
- asia_graph, {"asia", "smoking"}, {"dyspnea", "xray"}, {"bronchitis", "either"}
- )
- assert nx.d_separated(
- asia_graph, {"tuberculosis", "cancer"}, {"bronchitis"}, {"smoking", "xray"}
- )
- def test_undirected_graphs_are_not_supported():
- """
- Test that undirected graphs are not supported.
- d-separation and its related algorithms do not apply in
- the case of undirected graphs.
- """
- g = nx.path_graph(3, nx.Graph)
- with pytest.raises(nx.NetworkXNotImplemented):
- nx.d_separated(g, {0}, {1}, {2})
- with pytest.raises(nx.NetworkXNotImplemented):
- nx.is_minimal_d_separator(g, {0}, {1}, {2})
- with pytest.raises(nx.NetworkXNotImplemented):
- nx.minimal_d_separator(g, {0}, {1})
- def test_cyclic_graphs_raise_error():
- """
- Test that cycle graphs should cause erroring.
- This is because PGMs assume a directed acyclic graph.
- """
- g = nx.cycle_graph(3, nx.DiGraph)
- with pytest.raises(nx.NetworkXError):
- nx.d_separated(g, {0}, {1}, {2})
- with pytest.raises(nx.NetworkXError):
- nx.minimal_d_separator(g, {0}, {1})
- with pytest.raises(nx.NetworkXError):
- nx.is_minimal_d_separator(g, {0}, {1}, {2})
- def test_invalid_nodes_raise_error(asia_graph):
- """
- Test that graphs that have invalid nodes passed in raise errors.
- """
- with pytest.raises(nx.NodeNotFound):
- nx.d_separated(asia_graph, {0}, {1}, {2})
- with pytest.raises(nx.NodeNotFound):
- nx.is_minimal_d_separator(asia_graph, 0, 1, {2})
- with pytest.raises(nx.NodeNotFound):
- nx.minimal_d_separator(asia_graph, 0, 1)
- def test_minimal_d_separator():
- # Case 1:
- # create a graph A -> B <- C
- # B -> D -> E;
- # B -> F;
- # G -> E;
- edge_list = [("A", "B"), ("C", "B"), ("B", "D"), ("D", "E"), ("B", "F"), ("G", "E")]
- G = nx.DiGraph(edge_list)
- assert not nx.d_separated(G, {"B"}, {"E"}, set())
- # minimal set of the corresponding graph
- # for B and E should be (D,)
- Zmin = nx.minimal_d_separator(G, "B", "E")
- # the minimal separating set should pass the test for minimality
- assert nx.is_minimal_d_separator(G, "B", "E", Zmin)
- assert Zmin == {"D"}
- # Case 2:
- # create a graph A -> B -> C
- # B -> D -> C;
- edge_list = [("A", "B"), ("B", "C"), ("B", "D"), ("D", "C")]
- G = nx.DiGraph(edge_list)
- assert not nx.d_separated(G, {"A"}, {"C"}, set())
- Zmin = nx.minimal_d_separator(G, "A", "C")
- # the minimal separating set should pass the test for minimality
- assert nx.is_minimal_d_separator(G, "A", "C", Zmin)
- assert Zmin == {"B"}
|