utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. from collections import OrderedDict
# Names exported by a star-import of this module. ``_toposort`` is defined in
# this file but is not listed here; its leading underscore marks it non-public.
__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"]
  3. def raises(err, lamda):
  4. try:
  5. lamda()
  6. return False
  7. except err:
  8. return True
  9. def expand_tuples(L):
  10. """
  11. >>> expand_tuples([1, (2, 3)])
  12. [(1, 2), (1, 3)]
  13. >>> expand_tuples([1, 2])
  14. [(1, 2)]
  15. """
  16. if not L:
  17. return [()]
  18. elif not isinstance(L[0], tuple):
  19. rest = expand_tuples(L[1:])
  20. return [(L[0],) + t for t in rest]
  21. else:
  22. rest = expand_tuples(L[1:])
  23. return [(item,) + t for t in rest for item in L[0]]
  24. # Taken from theano/theano/gof/sched.py
  25. # Avoids licensing issues because this was written by Matthew Rocklin
  26. def _toposort(edges):
  27. """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
  28. inputs:
  29. edges - a dict of the form {a: {b, c}} where b and c depend on a
  30. outputs:
  31. L - an ordered list of nodes that satisfy the dependencies of edges
  32. >>> _toposort({1: (2, 3), 2: (3, )})
  33. [1, 2, 3]
  34. >>> # Closely follows the wikipedia page [2]
  35. >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
  36. >>> # Communications of the ACM
  37. >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms
  38. """
  39. incoming_edges = reverse_dict(edges)
  40. incoming_edges = OrderedDict((k, set(val))
  41. for k, val in incoming_edges.items())
  42. S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
  43. L = []
  44. while S:
  45. n, _ = S.popitem()
  46. L.append(n)
  47. for m in edges.get(n, ()):
  48. assert n in incoming_edges[m]
  49. incoming_edges[m].remove(n)
  50. if not incoming_edges[m]:
  51. S[m] = None
  52. if any(incoming_edges.get(v, None) for v in edges):
  53. raise ValueError("Input has cycles")
  54. return L
  55. def reverse_dict(d):
  56. """Reverses direction of dependence dict
  57. >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
  58. >>> reverse_dict(d) # doctest: +SKIP
  59. {1: ('a',), 2: ('a', 'b'), 3: ('b',)}
  60. :note: dict order are not deterministic. As we iterate on the
  61. input dict, it make the output of this function depend on the
  62. dict order. So this function output order should be considered
  63. as undeterministic.
  64. """
  65. result = OrderedDict() # type: ignore[var-annotated]
  66. for key in d:
  67. for val in d[key]:
  68. result[val] = result.get(val, tuple()) + (key, )
  69. return result
  70. # Taken from toolz
  71. # Avoids licensing issues because this version was authored by Matthew Rocklin
  72. def groupby(func, seq):
  73. """ Group a collection by a key function
  74. >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
  75. >>> groupby(len, names) # doctest: +SKIP
  76. {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
  77. >>> iseven = lambda x: x % 2 == 0
  78. >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
  79. {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
  80. See Also:
  81. ``countby``
  82. """
  83. d = OrderedDict() # type: ignore[var-annotated]
  84. for item in seq:
  85. key = func(item)
  86. if key not in d:
  87. d[key] = list()
  88. d[key].append(item)
  89. return d
  90. def typename(type):
  91. """Get the name of `type`.
  92. Parameters
  93. ----------
  94. type : Union[Type, Tuple[Type]]
  95. Returns
  96. -------
  97. str
  98. The name of `type` or a tuple of the names of the types in `type`.
  99. Examples
  100. --------
  101. >>> typename(int)
  102. 'int'
  103. >>> typename((int, float))
  104. '(int, float)'
  105. """
  106. try:
  107. return type.__name__
  108. except AttributeError:
  109. if len(type) == 1:
  110. return typename(*type)
  111. return '(%s)' % ', '.join(map(typename, type))