- """
- Algorithm for testing d-separation in DAGs.
- *d-separation* is a test for conditional independence in probability
- distributions that can be factorized using DAGs. It is a purely
- graphical test that uses the underlying graph and makes no reference
- to the actual distribution parameters. See [1]_ for a formal
- definition.
- The implementation is based on the conceptually simple linear time
- algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
- alternative algorithms.
- Here, we provide a brief overview of d-separation and related concepts that
- are relevant for understanding it:
- Blocking paths
- --------------
- Before giving an overview, we introduce the following terminology to describe paths:
- - "open" path: A path between two nodes that can be traversed
- - "blocked" path: A path between two nodes that cannot be traversed
- A **collider** is a triplet of nodes along a path of the form
- ``... u -> c <- v ...``, where ``c`` is a common successor of ``u`` and ``v``. By default,
- a path through a collider is considered "blocked". When the collider, or a
- descendant of the collider, is included in the d-separating set, the path
- through that collider is "open", and we call the node an open collider.
- The d-separating set blocks the paths between ``u`` and ``v``. If it includes colliders
- or their descendants, those colliders open up, enabling a path to be
- traversed if it is not blocked some other way.
- Illustration of D-separation with examples
- ------------------------------------------
- Two nodes, ``u`` and ``v``, are d-connected if at least one path between
- them is not blocked. A path is open when every collider on it is open and none
- of its intermediate non-collider nodes is in the d-separating set.
- For example, if the d-separating set is the empty set, then the following paths are
- unblocked between ``u`` and ``v``:
- - u <- z -> v
- - u -> w -> ... -> z -> v
- If, for example, 'z' is in the d-separating set, then 'z' blocks those paths
- between ``u`` and ``v``.
- Colliders block a path by default if they and their descendants are not included
- in the d-separating set. An example of a path that is blocked when the d-separating
- set is empty is:
- - u -> w -> ... -> z <- v
- because 'z' is a collider in this path and 'z' is not in the d-separating set. However,
- if 'z' or a descendant of 'z' is included in the d-separating set, then the path through
- the collider at 'z' (... -> z <- ...) is now "open".
- D-separation is concerned with blocking all paths between u and v. Therefore, a
- d-separating set between ``u`` and ``v`` is one where all paths are blocked.
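The blocking rules and example paths above can be checked mechanically, one path at a time. The following is a minimal pure-Python sketch and not part of NetworkX: the helper names ``descendants`` and ``path_is_open`` are hypothetical, the DAG is assumed to be given as a list of ``(parent, child)`` edges, and ``path`` is assumed to be a valid path in the undirected skeleton.

```python
def descendants(edges, node):
    # Collect every node reachable from ``node`` along directed edges.
    children = {}
    for parent, child in edges:
        children.setdefault(parent, set()).add(child)
    seen, stack = set(), [node]
    while stack:
        for child in children.get(stack.pop(), ()):
            if child not in seen:
                seen.add(child)
                stack.append(child)
    return seen


def path_is_open(edges, path, z):
    # Return True if ``path`` (a list of nodes) is open given the set ``z``.
    edge_set = set(edges)
    for a, b, c in zip(path, path[1:], path[2:]):
        if (a, b) in edge_set and (c, b) in edge_set:
            # ``b`` is a collider: open only if it or a descendant is in z.
            if b not in z and not (descendants(edges, b) & z):
                return False
        elif b in z:
            # A non-collider in z blocks the path.
            return False
    return True


# Hypothetical DAG with a fork (u <- z -> v), a chain (u -> w -> c),
# and a collider (w -> c <- v) that has a descendant d.
edges = [("z", "u"), ("z", "v"), ("u", "w"), ("w", "c"), ("v", "c"), ("c", "d")]

fork_open = path_is_open(edges, ["u", "z", "v"], set())
fork_blocked = not path_is_open(edges, ["u", "z", "v"], {"z"})
collider_blocked = not path_is_open(edges, ["u", "w", "c", "v"], set())
collider_opened = path_is_open(edges, ["u", "w", "c", "v"], {"d"})
```

With these inputs all four flags evaluate to ``True``: the fork is open until ``z`` is conditioned on, and the collider path is blocked until the collider's descendant ``d`` enters the conditioning set.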
- D-separation and its applications in probability
- ------------------------------------------------
- D-separation is commonly used in probabilistic graphical models. D-separation
- connects the idea of probabilistic "dependence" with separation in a graph. If
- one assumes the causal Markov condition [5]_, then d-separation implies conditional
- independence in probability distributions.
- Examples
- --------
- >>> import networkx as nx
- >>> # HMM graph with five states and observation nodes
- >>> g = nx.DiGraph()
- >>> g.add_edges_from(
- ... [
- ... ("S1", "S2"),
- ... ("S2", "S3"),
- ... ("S3", "S4"),
- ... ("S4", "S5"),
- ... ("S1", "O1"),
- ... ("S2", "O2"),
- ... ("S3", "O3"),
- ... ("S4", "O4"),
- ... ("S5", "O5"),
- ... ]
- ... )
- >>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
- >>> nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
- True
- References
- ----------
- .. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.
- .. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks.
- Cambridge: Cambridge University Press.
- .. [3] Shachter, R. D. (1998).
- Bayes-ball: rational pastime (for determining irrelevance and requisite
- information in belief networks and influence diagrams).
- In Proceedings of the Fourteenth Conference on Uncertainty in Artificial
- Intelligence (pp. 480–487).
- San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.
- .. [4] Koller, D., & Friedman, N. (2009).
- Probabilistic graphical models: principles and techniques. The MIT Press.
- .. [5] https://en.wikipedia.org/wiki/Causal_Markov_condition
- """
- from collections import deque
- import networkx as nx
- from networkx.utils import UnionFind, not_implemented_for
- __all__ = ["d_separated", "minimal_d_separator", "is_minimal_d_separator"]
- @not_implemented_for("undirected")
- def d_separated(G, x, y, z):
- """
- Return whether node sets ``x`` and ``y`` are d-separated by ``z``.
- Parameters
- ----------
- G : graph
- A NetworkX DAG.
- x : set
- First set of nodes in ``G``.
- y : set
- Second set of nodes in ``G``.
- z : set
- Set of conditioning nodes in ``G``. May be the empty set.
- Returns
- -------
- b : bool
- A boolean that is true if ``x`` is d-separated from ``y`` given ``z`` in ``G``.
- Raises
- ------
- NetworkXError
- The *d-separation* test is commonly used with directed
- graphical models which are acyclic. Accordingly, the algorithm
- raises a :exc:`NetworkXError` if the input graph is not a DAG.
- NodeNotFound
- If any of the input nodes are not found in the graph,
- a :exc:`NodeNotFound` exception is raised.
- Notes
- -----
- A d-separating set in a DAG is a set of nodes that
- blocks all paths between the two sets. A node in `z`
- blocks a path if it lies on the path and is not a collider on it.
- A collider blocks a path unless it, or one of its descendants, is in `z`.
- A collider structure along a path
- is ``... -> c <- ...`` where ``c`` is the collider node.
- https://en.wikipedia.org/wiki/Bayesian_network#d-separation
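The graph transformation this function applies (prune leaves outside ``x | y | z``, cut the edges leaving ``z``, then test connectivity) can be sketched without NetworkX. The version below is an illustration only: ``d_separated_sketch`` and its edge-list input are hypothetical, and it assumes the input is already a valid DAG, so it performs no validation.

```python
from collections import deque


def d_separated_sketch(edges, x, y, z):
    # Build parent/child adjacency from a list of (parent, child) edges.
    nodes = {n for edge in edges for n in edge} | x | y | z
    children = {n: set() for n in nodes}
    parents = {n: set() for n in nodes}
    for a, b in edges:
        children[a].add(b)
        parents[b].add(a)

    # Step 1: repeatedly delete leaves that are outside x | y | z.
    keep = x | y | z
    leaves = deque(n for n in nodes if not children[n])
    while leaves:
        leaf = leaves.popleft()
        if leaf in keep:
            continue
        for p in parents[leaf]:
            children[p].discard(leaf)
            if not children[p]:
                leaves.append(p)  # p just became a leaf
        del children[leaf], parents[leaf]

    # Step 2: delete every edge leaving a node in z.
    for n in z:
        for c in children[n]:
            parents[c].discard(n)
        children[n] = set()

    # Step 3: x and y are d-separated iff no undirected path joins them.
    seen, stack = set(x), list(x)
    while stack:
        cur = stack.pop()
        for nbr in children[cur] | parents[cur]:
            if nbr not in seen:
                seen.add(nbr)
                stack.append(nbr)
    return not (seen & y)


# The HMM from the module-level example, written as an edge list.
hmm_edges = [
    ("S1", "S2"), ("S2", "S3"), ("S3", "S4"), ("S4", "S5"),
    ("S1", "O1"), ("S2", "O2"), ("S3", "O3"), ("S4", "O4"), ("S5", "O5"),
]
before = {"S1", "S2", "O1", "O2"}
after = {"S4", "S5", "O4", "O5"}
separated_given_s3 = d_separated_sketch(hmm_edges, before, after, {"S3"})
separated_given_nothing = d_separated_sketch(hmm_edges, before, after, set())
```

On this graph ``separated_given_s3`` is ``True`` and ``separated_given_nothing`` is ``False``. The sketch replaces the union-find bookkeeping with a plain graph search from ``x``, which answers the same connectivity question.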
- """
- if not nx.is_directed_acyclic_graph(G):
- raise nx.NetworkXError("graph should be directed acyclic")
- union_xyz = x.union(y).union(z)
- if any(n not in G.nodes for n in union_xyz):
- raise nx.NodeNotFound("one or more specified nodes not found in the graph")
- G_copy = G.copy()
- # transform the graph by removing leaves that are not in x | y | z
- # until no more leaves can be removed.
- leaves = deque([n for n in G_copy.nodes if G_copy.out_degree[n] == 0])
- while len(leaves) > 0:
- leaf = leaves.popleft()
- if leaf not in union_xyz:
- for p in G_copy.predecessors(leaf):
- if G_copy.out_degree[p] == 1:
- leaves.append(p)
- G_copy.remove_node(leaf)
- # transform the graph by removing outgoing edges from the
- # conditioning set.
- edges_to_remove = list(G_copy.out_edges(z))
- G_copy.remove_edges_from(edges_to_remove)
- # use disjoint-set data structure to check if any node in `x`
- # occurs in the same weakly connected component as a node in `y`.
- disjoint_set = UnionFind(G_copy.nodes())
- for component in nx.weakly_connected_components(G_copy):
- disjoint_set.union(*component)
- disjoint_set.union(*x)
- disjoint_set.union(*y)
- if x and y and disjoint_set[next(iter(x))] == disjoint_set[next(iter(y))]:
- return False
- else:
- return True
- @not_implemented_for("undirected")
- def minimal_d_separator(G, u, v):
- """Compute a minimal d-separating set between 'u' and 'v'.
- A d-separating set in a DAG is a set of nodes that blocks all paths
- between the two nodes, 'u' and 'v'. This function
- constructs a d-separating set that is "minimal", meaning that no proper
- subset of it also d-separates 'u' and 'v'. A minimal set is not necessarily
- unique. For more details, see Notes.
- Parameters
- ----------
- G : graph
- A networkx DAG.
- u : node
- A node in the graph, G.
- v : node
- A node in the graph, G.
- Raises
- ------
- NetworkXError
- Raises a :exc:`NetworkXError` if the input graph is not a DAG.
- NodeNotFound
- If any of the input nodes are not found in the graph,
- a :exc:`NodeNotFound` exception is raised.
- References
- ----------
- .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators.
- Notes
- -----
- This function finds only *a* minimal d-separator. It does not guarantee
- uniqueness, since in a DAG there may be more than one minimal d-separator
- between two nodes. Moreover, it only handles separators
- between two nodes, not two sets; finding minimal d-separators between
- two sets of nodes is not supported.
- Uses the algorithm presented in [1]_. The complexity of the algorithm
- is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the
- number of edges in the moralized graph of the sub-graph consisting
- of only the ancestors of 'u' and 'v'. For full details, see [1]_.
- The algorithm works by constructing the moral graph consisting of just
- the ancestors of `u` and `v`. Then it constructs a candidate for
- a separating set ``Z'`` from the predecessors of `u` and `v`.
- Then BFS is run starting from `u`, marking the nodes of ``Z'`` that it
- reaches; those marked nodes form ``Z''``.
- Then BFS is run again starting from `v`, marking the nodes of ``Z''`` that it
- reaches. Those marked nodes are the returned minimal
- d-separating set.
- https://en.wikipedia.org/wiki/Bayesian_network#d-separation
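The steps described in the Notes can be traced on a small graph with a pure-Python sketch. This is an illustration under stated assumptions rather than the NetworkX implementation: the names ``ancestors``, ``bfs_with_marks`` and ``minimal_d_separator_sketch`` are hypothetical, the DAG is given as a list of ``(parent, child)`` edges, and ``u`` and ``v`` are assumed to be non-adjacent.

```python
from collections import deque
from itertools import combinations


def ancestors(parents, node):
    # Collect every node with a directed path into ``node``.
    seen, stack = set(), [node]
    while stack:
        for p in parents.get(stack.pop(), ()):
            if p not in seen:
                seen.add(p)
                stack.append(p)
    return seen


def bfs_with_marks(adj, start, check_set):
    # BFS from ``start``; mark nodes of ``check_set`` and do not expand them.
    visited, marked = {start}, set()
    queue = deque([start])
    while queue:
        for nbr in adj[queue.popleft()]:
            if nbr in visited:
                continue
            visited.add(nbr)
            if nbr in check_set:
                marked.add(nbr)
            else:
                queue.append(nbr)
    return marked


def minimal_d_separator_sketch(edges, u, v):
    parents = {}
    for a, b in edges:
        parents.setdefault(b, set()).add(a)
    anc = ancestors(parents, u) | ancestors(parents, v) | {u, v}

    # Moral graph of the ancestral subgraph: drop directions, marry parents.
    adj = {n: set() for n in anc}
    for child in anc:
        ps = parents.get(child, set())  # parents of an ancestor are ancestors
        for p in ps:
            adj[p].add(child)
            adj[child].add(p)
        for p, q in combinations(ps, 2):
            adj[p].add(q)
            adj[q].add(p)

    z_prime = parents.get(u, set()) | parents.get(v, set())
    z_dprime = bfs_with_marks(adj, u, z_prime)
    return bfs_with_marks(adj, v, z_dprime)


# Hypothetical DAG: a -> u, a -> v, b -> u.
separator = minimal_d_separator_sketch([("a", "u"), ("a", "v"), ("b", "u")], "u", "v")
```

Here ``separator`` is ``{"a"}``. The candidate set starts as ``{"a", "b"}``, but ``b`` is a parent of ``u`` that the search from ``v`` never reaches without passing through ``a``, so the second BFS drops it.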
- """
- if not nx.is_directed_acyclic_graph(G):
- raise nx.NetworkXError("graph should be directed acyclic")
- union_uv = {u, v}
- if any(n not in G.nodes for n in union_uv):
- raise nx.NodeNotFound("one or more specified nodes not found in the graph")
- # first, construct the set of ancestors of u and v
- x_anc = nx.ancestors(G, u)
- y_anc = nx.ancestors(G, v)
- D_anc_xy = x_anc.union(y_anc)
- D_anc_xy.update((u, v))
- # second, construct the moralization of the subgraph of Anc(X,Y)
- moral_G = nx.moral_graph(G.subgraph(D_anc_xy))
- # find a separating set Z' in moral_G
- Z_prime = set(G.predecessors(u)).union(set(G.predecessors(v)))
- # perform BFS on the moral graph from 'u', marking the nodes of Z' it reaches
- Z_dprime = _bfs_with_marks(moral_G, u, Z_prime)
- Z = _bfs_with_marks(moral_G, v, Z_dprime)
- return Z
- @not_implemented_for("undirected")
- def is_minimal_d_separator(G, u, v, z):
- """Determine if a d-separating set is minimal.
- A d-separating set, `z`, in a DAG is a set of nodes that blocks
- all paths between the two nodes, `u` and `v`. This function
- verifies that a set is "minimal", meaning that no proper subset of it
- d-separates the two nodes.
- Parameters
- ----------
- G : nx.DiGraph
- The graph.
- u : node
- A node in the graph.
- v : node
- A node in the graph.
- z : Set of nodes
- The set of nodes to check if it is a minimal d-separating set.
- Returns
- -------
- bool
- Whether or not the `z` separating set is minimal.
- Raises
- ------
- NetworkXError
- Raises a :exc:`NetworkXError` if the input graph is not a DAG.
- NodeNotFound
- If any of the input nodes are not found in the graph,
- a :exc:`NodeNotFound` exception is raised.
- References
- ----------
- .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators.
- Notes
- -----
- This function only verifies that a d-separating set is minimal
- between two nodes. Verifying that a d-separating set is minimal between
- two sets of nodes is not supported.
- Uses algorithm 2 presented in [1]_. The complexity of the algorithm
- is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the
- number of edges in the moralized graph of the sub-graph consisting
- of only the ancestors of ``u`` and ``v``.
- The algorithm works by constructing the moral graph consisting of just
- the ancestors of `u` and `v`. First, it performs BFS on the moral graph
- starting from `u` and marking any nodes it encounters that are part of
- the separating set, `z`. If a node is marked, then it does not continue
- along that path. In the second stage, BFS with markings is repeated on the
- moral graph starting from `v`. If at any stage, any node in `z` is
- not marked, then `z` is considered not minimal. If the end of the algorithm
- is reached, then `z` is minimal.
- For full details, see [1]_.
- https://en.wikipedia.org/wiki/Bayesian_network#d-separation
- """
- if not nx.is_directed_acyclic_graph(G):
- raise nx.NetworkXError("graph should be directed acyclic")
- union_uv = {u, v}
- union_uv.update(z)
- if any(n not in G.nodes for n in union_uv):
- raise nx.NodeNotFound("one or more specified nodes not found in the graph")
- x_anc = nx.ancestors(G, u)
- y_anc = nx.ancestors(G, v)
- xy_anc = x_anc.union(y_anc)
- # if Z contains any node which is not in ancestors of X or Y
- # then it is definitely not minimal
- if any(node not in xy_anc for node in z):
- return False
- D_anc_xy = x_anc.union(y_anc)
- D_anc_xy.update((u, v))
- # second, construct the moralization of the subgraph
- moral_G = nx.moral_graph(G.subgraph(D_anc_xy))
- # start BFS from X
- marks = _bfs_with_marks(moral_G, u, z)
- # if not all the Z is marked, then the set is not minimal
- if any(node not in marks for node in z):
- return False
- # similarly, start BFS from Y and check the marks
- marks = _bfs_with_marks(moral_G, v, z)
- # if not all the Z is marked, then the set is not minimal
- if any(node not in marks for node in z):
- return False
- return True
- @not_implemented_for("directed")
- def _bfs_with_marks(G, start_node, check_set):
- """Breadth-first-search with markings.
- Performs BFS starting from ``start_node`` and whenever a node
- inside ``check_set`` is met, it is "marked". Once a node is marked,
- BFS does not continue along that path. The resulting marked nodes
- are returned.
- Parameters
- ----------
- G : nx.Graph
- An undirected graph.
- start_node : node
- The start of the BFS.
- check_set : set
- The set of nodes to check against.
- Returns
- -------
- marked : set
- A set of nodes that were marked.
- """
- visited = {}
- marked = set()
- # a deque gives O(1) pops from the left, unlike list.pop(0)
- queue = deque()
- visited[start_node] = None
- queue.append(start_node)
- while queue:
- m = queue.popleft()
- for nbr in G.neighbors(m):
- if nbr not in visited:
- # memoize where we visited so far
- visited[nbr] = None
- # mark the node in Z' and do not continue along that path
- if nbr in check_set:
- marked.add(nbr)
- else:
- queue.append(nbr)
- return marked
|