"""
Algorithm for testing d-separation in DAGs.

*d-separation* is a test for conditional independence in probability
distributions that can be factorized using DAGs. It is a purely
graphical test that uses the underlying graph and makes no reference
to the actual distribution parameters. See [1]_ for a formal
definition.

The implementation is based on the conceptually simple linear time
algorithm presented in [2]_. Refer to [3]_, [4]_ for a couple of
alternative algorithms.

Here, we provide a brief overview of d-separation and related concepts that
are relevant for understanding it:

Blocking paths
--------------
First, we introduce the following terminology to describe paths:

- "open" path: A path between two nodes that can be traversed
- "blocked" path: A path between two nodes that cannot be traversed

A **collider** is a triplet of nodes along a path of the form
``... u -> c <- v ...``, where ``c`` is a common successor of ``u`` and ``v``;
the middle node ``c`` is also called the collider. A path through a collider
is considered "blocked". When a node that is a collider, or a descendant of a
collider, is included in the d-separating set, the path through that collider
node is "open". If the path through the collider node is open, we call this
node an open collider.

A d-separating set blocks paths between ``u`` and ``v``: a node in the set
blocks any path on which it appears as a non-collider. If you include colliders,
or their descendant nodes, in the d-separating set, then those colliders open up,
enabling a path to be traversed if it is not blocked some other way.

Illustration of D-separation with examples
------------------------------------------
For a pair of nodes ``u`` and ``v``, if at least one path between them is not
blocked, then ``u`` and ``v`` are not d-separated. Such an open path between
``u`` and ``v`` encounters neither a collider that remains closed nor a
non-collider variable in the d-separating set.

For example, if the d-separating set is the empty set, then the following paths
between ``u`` and ``v`` are unblocked:

- ``u <- z -> v``
- ``u -> w -> ... -> z -> v``

If, for example, ``z`` is in the d-separating set, then ``z`` blocks both of
those paths between ``u`` and ``v``.

A collider blocks a path by default, that is, whenever neither the collider nor
any of its descendants is included in the d-separating set. An example of a path
that is blocked when the d-separating set is empty is:

- ``u -> w -> ... -> z <- v``

because ``z`` is a collider on this path and ``z`` is not in the d-separating
set. However, if ``z`` or a descendant of ``z`` is included in the d-separating
set, then the path through the collider at ``z`` (``... -> z <- ...``) is now
"open".

D-separation is concerned with blocking all paths between ``u`` and ``v``.
Therefore, a d-separating set between ``u`` and ``v`` is one that blocks every
path between them.

D-separation and its applications in probability
------------------------------------------------
D-separation is commonly used in probabilistic graphical models. D-separation
connects the idea of probabilistic "dependence" with separation in a graph. If
one assumes the causal Markov condition [5]_, then d-separation implies conditional
independence in probability distributions.

Examples
--------

>>> # HMM graph with five states and observation nodes
>>> g = nx.DiGraph()
>>> g.add_edges_from(
...     [
...         ("S1", "S2"),
...         ("S2", "S3"),
...         ("S3", "S4"),
...         ("S4", "S5"),
...         ("S1", "O1"),
...         ("S2", "O2"),
...         ("S3", "O3"),
...         ("S4", "O4"),
...         ("S5", "O5"),
...     ]
... )

>>> # states/obs before 'S3' are d-separated from states/obs after 'S3'
>>> nx.d_separated(g, {"S1", "S2", "O1", "O2"}, {"S4", "S5", "O4", "O5"}, {"S3"})
True

References
----------

.. [1] Pearl, J. (2009). Causality. Cambridge: Cambridge University Press.

.. [2] Darwiche, A. (2009). Modeling and reasoning with Bayesian networks.
   Cambridge: Cambridge University Press.

.. [3] Shachter, R. D. (1998).
   Bayes-ball: rational pastime (for determining irrelevance and requisite
   information in belief networks and influence diagrams).
   In Proceedings of the Fourteenth Conference on Uncertainty in Artificial
   Intelligence (pp. 480–487).
   San Francisco, CA, USA: Morgan Kaufmann Publishers Inc.

.. [4] Koller, D., & Friedman, N. (2009).
   Probabilistic graphical models: principles and techniques. The MIT Press.

.. [5] https://en.wikipedia.org/wiki/Causal_Markov_condition

"""

from collections import deque

import networkx as nx
from networkx.utils import UnionFind, not_implemented_for

__all__ = ["d_separated", "minimal_d_separator", "is_minimal_d_separator"]
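

# Illustrative sketch (hypothetical helper, not part of the networkx API and
# not used by the functions below): decide whether a single, explicitly given
# path is blocked by ``z``, following the rules described in the module
# docstring. A non-collider that is in ``z`` blocks the path; a collider
# ``left -> mid <- right`` blocks it unless ``mid`` or one of its descendants
# is in ``z``. Assumptions: ``path`` is a list of nodes in which consecutive
# nodes are joined by an edge in either direction, and ``G`` is a DAG that
# provides ``has_edge``, ``predecessors`` and ``successors``, as
# ``nx.DiGraph`` does. Returns True if the path is blocked. For example, with
# edges u -> w, w -> z, v -> z, the path ["u", "w", "z", "v"] is blocked when
# ``z`` is empty and open when ``z`` is {"z"}.
def _path_is_blocked(G, path, z):
    z = set(z)
    for left, mid, right in zip(path, path[1:], path[2:]):
        if G.has_edge(left, mid) and G.has_edge(right, mid):
            # collider: collect ``mid`` together with all of its descendants
            descendants = set()
            stack = [mid]
            while stack:
                n = stack.pop()
                if n not in descendants:
                    descendants.add(n)
                    stack.extend(G.successors(n))
            # the collider stays closed if none of them is conditioned on
            if descendants.isdisjoint(z):
                return True
        elif mid in z:
            # chain or fork through a conditioned node
            return True
    return False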


@not_implemented_for("undirected")
def d_separated(G, x, y, z):
    """
    Return whether node sets ``x`` and ``y`` are d-separated by ``z``.

    Parameters
    ----------
    G : graph
        A NetworkX DAG.
    x : set
        First set of nodes in ``G``.
    y : set
        Second set of nodes in ``G``.
    z : set
        Set of conditioning nodes in ``G``. Can be an empty set.

    Returns
    -------
    b : bool
        A boolean that is true if ``x`` is d-separated from ``y`` given ``z`` in ``G``.

    Raises
    ------
    NetworkXError
        The *d-separation* test is commonly used with directed
        graphical models which are acyclic. Accordingly, the algorithm
        raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    Notes
    -----
    A d-separating set in a DAG is a set of nodes that
    blocks all paths between the two sets. Nodes in `z`
    block a path if they are part of the path and are not colliders
    on it. A collider blocks a path unless the collider, or one of its
    descendants, is in `z`. A collider structure along a path
    is ``... -> c <- ...`` where ``c`` is the collider node.

    https://en.wikipedia.org/wiki/Bayesian_network#d-separation
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    union_xyz = x.union(y).union(z)

    if any(n not in G.nodes for n in union_xyz):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    G_copy = G.copy()

    # transform the graph by removing leaves that are not in x | y | z
    # until no more leaves can be removed.
    leaves = deque([n for n in G_copy.nodes if G_copy.out_degree[n] == 0])
    while len(leaves) > 0:
        leaf = leaves.popleft()
        if leaf not in union_xyz:
            for p in G_copy.predecessors(leaf):
                if G_copy.out_degree[p] == 1:
                    leaves.append(p)
            G_copy.remove_node(leaf)

    # transform the graph by removing outgoing edges from the
    # conditioning set.
    edges_to_remove = list(G_copy.out_edges(z))
    G_copy.remove_edges_from(edges_to_remove)

    # use disjoint-set data structure to check if any node in `x`
    # occurs in the same weakly connected component as a node in `y`.
    disjoint_set = UnionFind(G_copy.nodes())
    for component in nx.weakly_connected_components(G_copy):
        disjoint_set.union(*component)
    disjoint_set.union(*x)
    disjoint_set.union(*y)

    if x and y and disjoint_set[next(iter(x))] == disjoint_set[next(iter(y))]:
        return False
    else:
        return True
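

# Illustrative sketch (hypothetical helper, not part of the networkx API and
# not used by ``d_separated``): an alternative linear-time test in the spirit
# of the "reachable" procedure of Koller & Friedman ([4] in the module
# docstring), which is closely related to Bayes-ball ([3]). Instead of pruning
# the graph, it searches over (node, direction) states: "up" means the node
# was entered from one of its children, "down" means it was entered from one
# of its parents. It assumes ``G`` is a DAG exposing ``predecessors`` and
# ``successors`` (as ``nx.DiGraph`` does), that ``x``, ``y`` and ``z`` are
# sets, and it relies on the module-level ``deque`` import.
def _d_separated_reachable(G, x, y, z):
    # Phase 1: nodes in ``z`` together with all of their ancestors. A collider
    # is open exactly when it belongs to this set.
    anc_z = set()
    stack = list(z)
    while stack:
        n = stack.pop()
        if n not in anc_z:
            anc_z.add(n)
            stack.extend(G.predecessors(n))

    # Phase 2: breadth-first search over (node, direction) states from ``x``.
    queue = deque((n, "up") for n in x)
    visited = set()
    reachable = set()
    while queue:
        node, direction = queue.popleft()
        if (node, direction) in visited:
            continue
        visited.add((node, direction))
        if node not in z:
            reachable.add(node)
        if direction == "up" and node not in z:
            # chain or fork through an unconditioned node: go anywhere
            queue.extend((p, "up") for p in G.predecessors(node))
            queue.extend((c, "down") for c in G.successors(node))
        elif direction == "down":
            if node not in z:
                # chain through an unconditioned node: keep going down
                queue.extend((c, "down") for c in G.successors(node))
            if node in anc_z:
                # open collider: turn around towards the other parents
                queue.extend((p, "up") for p in G.predecessors(node))

    # ``x`` and ``y`` are d-separated iff no node of ``y`` is reachable
    return reachable.isdisjoint(y)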


@not_implemented_for("undirected")
def minimal_d_separator(G, u, v):
    """Compute a minimal d-separating set between 'u' and 'v'.

    A d-separating set in a DAG is a set of nodes that blocks all paths
    between the two nodes, 'u' and 'v'. This function
    constructs a d-separating set that is "minimal", meaning that no proper
    subset of it also d-separates 'u' and 'v'. A minimal set is not
    necessarily the smallest (minimum-size) d-separating set, and it is not
    necessarily unique. For more details, see Notes.

    Parameters
    ----------
    G : graph
        A networkx DAG.
    u : node
        A node in the graph, G.
    v : node
        A node in the graph, G.

    Returns
    -------
    z : set
        The minimal d-separating set found. The algorithm assumes that
        `u` and `v` can be d-separated, i.e. that they are not adjacent;
        this is not checked.

    Raises
    ------
    NetworkXError
        Raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    References
    ----------
    .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators.

    Notes
    -----
    This function only finds *a* minimal d-separator. It does not guarantee
    uniqueness, since in a DAG there may be more than one minimal d-separator
    between two nodes. Moreover, this only finds minimal separators
    between two nodes, not two sets. Finding minimal d-separators between
    two sets of nodes is not supported.

    Uses the algorithm presented in [1]_. The complexity of the algorithm
    is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the
    number of edges in the moralized graph of the sub-graph consisting
    of only the ancestors of 'u' and 'v'. For full details, see [1]_.

    The algorithm works by constructing the moral graph consisting of just
    the ancestors of `u` and `v`. Then it constructs a candidate for
    a separating set ``Z'`` from the predecessors of `u` and `v`.
    Then BFS is run starting from `u`, marking the nodes of ``Z'`` that it
    reaches without passing through another node of ``Z'``; call these
    marked nodes ``Z''``. Then BFS is run again starting from `v`, marking
    the nodes of ``Z''`` that it reaches in the same way. Those marked nodes
    are the returned minimal d-separating set.

    https://en.wikipedia.org/wiki/Bayesian_network#d-separation
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    union_uv = {u, v}

    if any(n not in G.nodes for n in union_uv):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    # first construct the set of ancestors of u and v
    x_anc = nx.ancestors(G, u)
    y_anc = nx.ancestors(G, v)
    D_anc_xy = x_anc.union(y_anc)
    D_anc_xy.update((u, v))

    # second, construct the moralization of the subgraph of Anc(u, v)
    moral_G = nx.moral_graph(G.subgraph(D_anc_xy))

    # find a separating set Z' in moral_G
    Z_prime = set(G.predecessors(u)).union(set(G.predecessors(v)))

    # BFS from 'u' marks the nodes of Z' it reaches (Z''); BFS from 'v'
    # then keeps the nodes of Z'' that it reaches
    Z_dprime = _bfs_with_marks(moral_G, u, Z_prime)
    Z = _bfs_with_marks(moral_G, v, Z_dprime)
    return Z
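

# Illustrative sketch (hypothetical helper, not part of the networkx API and
# not used by the functions in this module): what the "moral graph of the
# ancestors" in the Notes above amounts to. It collects ``nodes`` together
# with all of their ancestors, drops edge directions, and "marries" (connects)
# every pair of parents that share a child. The result is returned as a plain
# adjacency dict ``{node: set_of_neighbors}`` rather than an ``nx.Graph``;
# ``G`` is assumed to be a DAG exposing ``predecessors``. The functions in
# this module obtain the same structure with ``nx.moral_graph(G.subgraph(...))``.
def _ancestral_moral_graph(G, nodes):
    # ancestral closure of ``nodes`` (every parent of a member is a member)
    anc = set()
    stack = list(nodes)
    while stack:
        n = stack.pop()
        if n not in anc:
            anc.add(n)
            stack.extend(G.predecessors(n))

    adj = {n: set() for n in anc}
    for child in anc:
        parents = list(G.predecessors(child))
        for p in parents:
            # keep the original edge, without its direction
            adj[child].add(p)
            adj[p].add(child)
        for i, p in enumerate(parents):
            for q in parents[i + 1 :]:
                # marry the co-parents of ``child``
                adj[p].add(q)
                adj[q].add(p)
    return adj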


@not_implemented_for("undirected")
def is_minimal_d_separator(G, u, v, z):
    """Determine if a d-separating set is minimal.

    A d-separating set, `z`, in a DAG is a set of nodes that blocks
    all paths between the two nodes, `u` and `v`. This function
    verifies that a set is "minimal", meaning that no proper subset of it
    also d-separates the two nodes. It assumes that `z` already
    d-separates `u` and `v` and does not check this.

    Parameters
    ----------
    G : nx.DiGraph
        The graph.
    u : node
        A node in the graph.
    v : node
        A node in the graph.
    z : set
        The set of nodes to check if it is a minimal d-separating set.

    Returns
    -------
    bool
        Whether or not the `z` separating set is minimal.

    Raises
    ------
    NetworkXError
        Raises a :exc:`NetworkXError` if the input graph is not a DAG.

    NodeNotFound
        If any of the input nodes are not found in the graph,
        a :exc:`NodeNotFound` exception is raised.

    References
    ----------
    .. [1] Tian, J., & Paz, A. (1998). Finding Minimal D-separators.

    Notes
    -----
    This function only verifies that a d-separating set is minimal
    between two nodes. Verifying that a d-separating set is minimal between
    two sets of nodes is not supported.

    Uses algorithm 2 presented in [1]_. The complexity of the algorithm
    is :math:`O(|E_{An}^m|)`, where :math:`|E_{An}^m|` stands for the
    number of edges in the moralized graph of the sub-graph consisting
    of only the ancestors of ``u`` and ``v``.

    The algorithm first rejects `z` if it contains a node that is not an
    ancestor of `u` or `v`. It then constructs the moral graph consisting of
    just the ancestors of `u` and `v`. Next, it performs BFS on the moral graph
    starting from `u` and marking any nodes it encounters that are part of
    the separating set, `z`. If a node is marked, then it does not continue
    along that path. In the second stage, BFS with markings is repeated on the
    moral graph starting from `v`. If at any stage, any node in `z` is
    not marked, then `z` is considered not minimal. If the end of the algorithm
    is reached, then `z` is minimal.

    For full details, see [1]_.

    https://en.wikipedia.org/wiki/Bayesian_network#d-separation
    """
    if not nx.is_directed_acyclic_graph(G):
        raise nx.NetworkXError("graph should be directed acyclic")

    union_uv = {u, v}
    union_uv.update(z)

    if any(n not in G.nodes for n in union_uv):
        raise nx.NodeNotFound("one or more specified nodes not found in the graph")

    x_anc = nx.ancestors(G, u)
    y_anc = nx.ancestors(G, v)
    xy_anc = x_anc.union(y_anc)

    # if z contains any node which is not an ancestor of u or v,
    # then it is definitely not minimal
    if any(node not in xy_anc for node in z):
        return False

    D_anc_xy = x_anc.union(y_anc)
    D_anc_xy.update((u, v))

    # second, construct the moralization of the subgraph
    moral_G = nx.moral_graph(G.subgraph(D_anc_xy))

    # start BFS from u
    marks = _bfs_with_marks(moral_G, u, z)

    # if not all of z is marked, then the set is not minimal
    if any(node not in marks for node in z):
        return False

    # similarly, start BFS from v and check the marks
    marks = _bfs_with_marks(moral_G, v, z)
    # if not all of z is marked, then the set is not minimal
    if any(node not in marks for node in z):
        return False

    return True
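

# Illustrative sketch (hypothetical helper, not part of the networkx API and
# not used by the functions in this module): a brute-force check of
# minimality straight from the definition, which may be useful for
# cross-checking ``is_minimal_d_separator`` on small graphs. ``is_separator``
# is an assumed callable that reports whether a given set d-separates the two
# nodes of interest, e.g. ``lambda s: d_separated(G, {u}, {v}, set(s))``.
# Unlike ``is_minimal_d_separator``, this also verifies that ``z`` itself
# separates. Its cost grows exponentially with ``len(z)``. For example, in the
# chain a -> b -> c -> d every set meeting {"b", "c"} separates a from d, so
# {"b"} is minimal while {"b", "c"} is not.
def _is_minimal_by_definition(is_separator, z):
    from itertools import combinations

    z = set(z)
    if not is_separator(z):
        return False
    # minimal: no proper subset of ``z`` is itself a separator
    for k in range(len(z)):
        for subset in combinations(z, k):
            if is_separator(set(subset)):
                return False
    return True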


@not_implemented_for("directed")
def _bfs_with_marks(G, start_node, check_set):
    """Breadth-first-search with markings.

    Performs BFS starting from ``start_node`` and whenever a node
    inside ``check_set`` is met, it is "marked". Once a node is marked,
    BFS does not continue along that path. The resulting marked nodes
    are returned.

    Parameters
    ----------
    G : nx.Graph
        An undirected graph.
    start_node : node
        The start of the BFS.
    check_set : set
        The set of nodes to check against.

    Returns
    -------
    marked : set
        A set of nodes that were marked.
    """
    visited = {start_node}
    marked = set()
    queue = deque([start_node])

    while queue:
        m = queue.popleft()
        for nbr in G.neighbors(m):
            if nbr not in visited:
                # memoize where we visited so far
                visited.add(nbr)

                # mark the node in check_set and do not continue along that path
                if nbr in check_set:
                    marked.add(nbr)
                else:
                    queue.append(nbr)
    return marked