unification_tools.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import collections
  2. import operator
  3. from functools import reduce
  4. from collections.abc import Mapping
  5. __all__ = ('merge', 'merge_with', 'valmap', 'keymap', 'itemmap',
  6. 'valfilter', 'keyfilter', 'itemfilter',
  7. 'assoc', 'dissoc', 'assoc_in', 'update_in', 'get_in')
  8. def _get_factory(f, kwargs):
  9. factory = kwargs.pop('factory', dict)
  10. if kwargs:
  11. raise TypeError("{}() got an unexpected keyword argument "
  12. "'{}'".format(f.__name__, kwargs.popitem()[0]))
  13. return factory
  14. def merge(*dicts, **kwargs):
  15. """ Merge a collection of dictionaries
  16. >>> merge({1: 'one'}, {2: 'two'})
  17. {1: 'one', 2: 'two'}
  18. Later dictionaries have precedence
  19. >>> merge({1: 2, 3: 4}, {3: 3, 4: 4})
  20. {1: 2, 3: 3, 4: 4}
  21. See Also:
  22. merge_with
  23. """
  24. if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
  25. dicts = dicts[0]
  26. factory = _get_factory(merge, kwargs)
  27. rv = factory()
  28. for d in dicts:
  29. rv.update(d)
  30. return rv
  31. def merge_with(func, *dicts, **kwargs):
  32. """ Merge dictionaries and apply function to combined values
  33. A key may occur in more than one dict, and all values mapped from the key
  34. will be passed to the function as a list, such as func([val1, val2, ...]).
  35. >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20})
  36. {1: 11, 2: 22}
  37. >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP
  38. {1: 1, 2: 2, 3: 30}
  39. See Also:
  40. merge
  41. """
  42. if len(dicts) == 1 and not isinstance(dicts[0], Mapping):
  43. dicts = dicts[0]
  44. factory = _get_factory(merge_with, kwargs)
  45. result = factory()
  46. for d in dicts:
  47. for k, v in d.items():
  48. if k not in result:
  49. result[k] = [v]
  50. else:
  51. result[k].append(v)
  52. return valmap(func, result, factory)
  53. def valmap(func, d, factory=dict):
  54. """ Apply function to values of dictionary
  55. >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
  56. >>> valmap(sum, bills) # doctest: +SKIP
  57. {'Alice': 65, 'Bob': 45}
  58. See Also:
  59. keymap
  60. itemmap
  61. """
  62. rv = factory()
  63. rv.update(zip(d.keys(), map(func, d.values())))
  64. return rv
  65. def keymap(func, d, factory=dict):
  66. """ Apply function to keys of dictionary
  67. >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]}
  68. >>> keymap(str.lower, bills) # doctest: +SKIP
  69. {'alice': [20, 15, 30], 'bob': [10, 35]}
  70. See Also:
  71. valmap
  72. itemmap
  73. """
  74. rv = factory()
  75. rv.update(zip(map(func, d.keys()), d.values()))
  76. return rv
  77. def itemmap(func, d, factory=dict):
  78. """ Apply function to items of dictionary
  79. >>> accountids = {"Alice": 10, "Bob": 20}
  80. >>> itemmap(reversed, accountids) # doctest: +SKIP
  81. {10: "Alice", 20: "Bob"}
  82. See Also:
  83. keymap
  84. valmap
  85. """
  86. rv = factory()
  87. rv.update(map(func, d.items()))
  88. return rv
  89. def valfilter(predicate, d, factory=dict):
  90. """ Filter items in dictionary by value
  91. >>> iseven = lambda x: x % 2 == 0
  92. >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
  93. >>> valfilter(iseven, d)
  94. {1: 2, 3: 4}
  95. See Also:
  96. keyfilter
  97. itemfilter
  98. valmap
  99. """
  100. rv = factory()
  101. for k, v in d.items():
  102. if predicate(v):
  103. rv[k] = v
  104. return rv
  105. def keyfilter(predicate, d, factory=dict):
  106. """ Filter items in dictionary by key
  107. >>> iseven = lambda x: x % 2 == 0
  108. >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
  109. >>> keyfilter(iseven, d)
  110. {2: 3, 4: 5}
  111. See Also:
  112. valfilter
  113. itemfilter
  114. keymap
  115. """
  116. rv = factory()
  117. for k, v in d.items():
  118. if predicate(k):
  119. rv[k] = v
  120. return rv
  121. def itemfilter(predicate, d, factory=dict):
  122. """ Filter items in dictionary by item
  123. >>> def isvalid(item):
  124. ... k, v = item
  125. ... return k % 2 == 0 and v < 4
  126. >>> d = {1: 2, 2: 3, 3: 4, 4: 5}
  127. >>> itemfilter(isvalid, d)
  128. {2: 3}
  129. See Also:
  130. keyfilter
  131. valfilter
  132. itemmap
  133. """
  134. rv = factory()
  135. for item in d.items():
  136. if predicate(item):
  137. k, v = item
  138. rv[k] = v
  139. return rv
  140. def assoc(d, key, value, factory=dict):
  141. """ Return a new dict with new key value pair
  142. New dict has d[key] set to value. Does not modify the initial dictionary.
  143. >>> assoc({'x': 1}, 'x', 2)
  144. {'x': 2}
  145. >>> assoc({'x': 1}, 'y', 3) # doctest: +SKIP
  146. {'x': 1, 'y': 3}
  147. """
  148. d2 = factory()
  149. d2.update(d)
  150. d2[key] = value
  151. return d2
  152. def dissoc(d, *keys, **kwargs):
  153. """ Return a new dict with the given key(s) removed.
  154. New dict has d[key] deleted for each supplied key.
  155. Does not modify the initial dictionary.
  156. >>> dissoc({'x': 1, 'y': 2}, 'y')
  157. {'x': 1}
  158. >>> dissoc({'x': 1, 'y': 2}, 'y', 'x')
  159. {}
  160. >>> dissoc({'x': 1}, 'y') # Ignores missing keys
  161. {'x': 1}
  162. """
  163. factory = _get_factory(dissoc, kwargs)
  164. d2 = factory()
  165. if len(keys) < len(d) * .6:
  166. d2.update(d)
  167. for key in keys:
  168. if key in d2:
  169. del d2[key]
  170. else:
  171. remaining = set(d)
  172. remaining.difference_update(keys)
  173. for k in remaining:
  174. d2[k] = d[k]
  175. return d2
  176. def assoc_in(d, keys, value, factory=dict):
  177. """ Return a new dict with new, potentially nested, key value pair
  178. >>> purchase = {'name': 'Alice',
  179. ... 'order': {'items': ['Apple', 'Orange'],
  180. ... 'costs': [0.50, 1.25]},
  181. ... 'credit card': '5555-1234-1234-1234'}
  182. >>> assoc_in(purchase, ['order', 'costs'], [0.25, 1.00]) # doctest: +SKIP
  183. {'credit card': '5555-1234-1234-1234',
  184. 'name': 'Alice',
  185. 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}}
  186. """
  187. return update_in(d, keys, lambda x: value, value, factory)
  188. def update_in(d, keys, func, default=None, factory=dict):
  189. """ Update value in a (potentially) nested dictionary
  190. inputs:
  191. d - dictionary on which to operate
  192. keys - list or tuple giving the location of the value to be changed in d
  193. func - function to operate on that value
  194. If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the
  195. original dictionary with v replaced by func(v), but does not mutate the
  196. original dictionary.
  197. If k0 is not a key in d, update_in creates nested dictionaries to the depth
  198. specified by the keys, with the innermost value set to func(default).
  199. >>> inc = lambda x: x + 1
  200. >>> update_in({'a': 0}, ['a'], inc)
  201. {'a': 1}
  202. >>> transaction = {'name': 'Alice',
  203. ... 'purchase': {'items': ['Apple', 'Orange'],
  204. ... 'costs': [0.50, 1.25]},
  205. ... 'credit card': '5555-1234-1234-1234'}
  206. >>> update_in(transaction, ['purchase', 'costs'], sum) # doctest: +SKIP
  207. {'credit card': '5555-1234-1234-1234',
  208. 'name': 'Alice',
  209. 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}}
  210. >>> # updating a value when k0 is not in d
  211. >>> update_in({}, [1, 2, 3], str, default="bar")
  212. {1: {2: {3: 'bar'}}}
  213. >>> update_in({1: 'foo'}, [2, 3, 4], inc, 0)
  214. {1: 'foo', 2: {3: {4: 1}}}
  215. """
  216. ks = iter(keys)
  217. k = next(ks)
  218. rv = inner = factory()
  219. rv.update(d)
  220. for key in ks:
  221. if k in d:
  222. d = d[k]
  223. dtemp = factory()
  224. dtemp.update(d)
  225. else:
  226. d = dtemp = factory()
  227. inner[k] = inner = dtemp
  228. k = key
  229. if k in d:
  230. inner[k] = func(d[k])
  231. else:
  232. inner[k] = func(default)
  233. return rv
  234. def get_in(keys, coll, default=None, no_default=False):
  235. """ Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys.
  236. If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless
  237. ``no_default`` is specified, then it raises KeyError or IndexError.
  238. ``get_in`` is a generalization of ``operator.getitem`` for nested data
  239. structures such as dictionaries and lists.
  240. >>> transaction = {'name': 'Alice',
  241. ... 'purchase': {'items': ['Apple', 'Orange'],
  242. ... 'costs': [0.50, 1.25]},
  243. ... 'credit card': '5555-1234-1234-1234'}
  244. >>> get_in(['purchase', 'items', 0], transaction)
  245. 'Apple'
  246. >>> get_in(['name'], transaction)
  247. 'Alice'
  248. >>> get_in(['purchase', 'total'], transaction)
  249. >>> get_in(['purchase', 'items', 'apple'], transaction)
  250. >>> get_in(['purchase', 'items', 10], transaction)
  251. >>> get_in(['purchase', 'total'], transaction, 0)
  252. 0
  253. >>> get_in(['y'], {}, no_default=True)
  254. Traceback (most recent call last):
  255. ...
  256. KeyError: 'y'
  257. See Also:
  258. itertoolz.get
  259. operator.getitem
  260. """
  261. try:
  262. return reduce(operator.getitem, keys, coll)
  263. except (KeyError, IndexError, TypeError):
  264. if no_default:
  265. raise
  266. return default
  267. def getter(index):
  268. if isinstance(index, list):
  269. if len(index) == 1:
  270. index = index[0]
  271. return lambda x: (x[index],)
  272. elif index:
  273. return operator.itemgetter(*index)
  274. else:
  275. return lambda x: ()
  276. else:
  277. return operator.itemgetter(index)
  278. def groupby(key, seq):
  279. """ Group a collection by a key function
  280. >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
  281. >>> groupby(len, names) # doctest: +SKIP
  282. {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
  283. >>> iseven = lambda x: x % 2 == 0
  284. >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
  285. {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}
  286. Non-callable keys imply grouping on a member.
  287. >>> groupby('gender', [{'name': 'Alice', 'gender': 'F'},
  288. ... {'name': 'Bob', 'gender': 'M'},
  289. ... {'name': 'Charlie', 'gender': 'M'}]) # doctest:+SKIP
  290. {'F': [{'gender': 'F', 'name': 'Alice'}],
  291. 'M': [{'gender': 'M', 'name': 'Bob'},
  292. {'gender': 'M', 'name': 'Charlie'}]}
  293. Not to be confused with ``itertools.groupby``
  294. See Also:
  295. countby
  296. """
  297. if not callable(key):
  298. key = getter(key)
  299. d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated]
  300. for item in seq:
  301. d[key(item)](item)
  302. rv = {}
  303. for k, v in d.items():
  304. rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined]
  305. return rv
  306. def first(seq):
  307. """ The first element in a sequence
  308. >>> first('ABC')
  309. 'A'
  310. """
  311. return next(iter(seq))