utils.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. __all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"]
  2. def hashable(x):
  3. try:
  4. hash(x)
  5. return True
  6. except TypeError:
  7. return False
  8. def transitive_get(key, d):
  9. """ Transitive dict.get
  10. >>> d = {1: 2, 2: 3, 3: 4}
  11. >>> d.get(1)
  12. 2
  13. >>> transitive_get(1, d)
  14. 4
  15. """
  16. while hashable(key) and key in d:
  17. key = d[key]
  18. return key
  19. def raises(err, lamda):
  20. try:
  21. lamda()
  22. return False
  23. except err:
  24. return True
  25. # Taken from theano/theano/gof/sched.py
  26. # Avoids licensing issues because this was written by Matthew Rocklin
  27. def _toposort(edges):
  28. """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
  29. inputs:
  30. edges - a dict of the form {a: {b, c}} where b and c depend on a
  31. outputs:
  32. L - an ordered list of nodes that satisfy the dependencies of edges
  33. >>> # xdoctest: +SKIP
  34. >>> _toposort({1: (2, 3), 2: (3, )})
  35. [1, 2, 3]
  36. Closely follows the wikipedia page [2]
  37. [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
  38. Communications of the ACM
  39. [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
  40. """
  41. incoming_edges = reverse_dict(edges)
  42. incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
  43. S = ({v for v in edges if v not in incoming_edges})
  44. L = []
  45. while S:
  46. n = S.pop()
  47. L.append(n)
  48. for m in edges.get(n, ()):
  49. assert n in incoming_edges[m]
  50. incoming_edges[m].remove(n)
  51. if not incoming_edges[m]:
  52. S.add(m)
  53. if any(incoming_edges.get(v, None) for v in edges):
  54. raise ValueError("Input has cycles")
  55. return L
  56. def reverse_dict(d):
  57. """Reverses direction of dependence dict
  58. >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
  59. >>> reverse_dict(d) # doctest: +SKIP
  60. {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
  61. :note: dict order are not deterministic. As we iterate on the
  62. input dict, it make the output of this function depend on the
  63. dict order. So this function output order should be considered
  64. as undeterministic.
  65. """
  66. result = {} # type: ignore[var-annotated]
  67. for key in d:
  68. for val in d[key]:
  69. result[val] = result.get(val, tuple()) + (key, )
  70. return result
  71. def xfail(func):
  72. try:
  73. func()
  74. raise Exception("XFailed test passed") # pragma:nocover
  75. except Exception:
  76. pass
  77. def freeze(d):
  78. """ Freeze container to hashable form
  79. >>> freeze(1)
  80. 1
  81. >>> freeze([1, 2])
  82. (1, 2)
  83. >>> freeze({1: 2}) # doctest: +SKIP
  84. frozenset([(1, 2)])
  85. """
  86. if isinstance(d, dict):
  87. return frozenset(map(freeze, d.items()))
  88. if isinstance(d, set):
  89. return frozenset(map(freeze, d))
  90. if isinstance(d, (tuple, list)):
  91. return tuple(map(freeze, d))
  92. return d