utils.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. import bisect
  2. from collections import defaultdict
  3. from sympy.combinatorics import Permutation
  4. from sympy.core.containers import Tuple
  5. from sympy.core.numbers import Integer
  6. def _get_mapping_from_subranks(subranks):
  7. mapping = {}
  8. counter = 0
  9. for i, rank in enumerate(subranks):
  10. for j in range(rank):
  11. mapping[counter] = (i, j)
  12. counter += 1
  13. return mapping
  14. def _get_contraction_links(args, subranks, *contraction_indices):
  15. mapping = _get_mapping_from_subranks(subranks)
  16. contraction_tuples = [[mapping[j] for j in i] for i in contraction_indices]
  17. dlinks = defaultdict(dict)
  18. for links in contraction_tuples:
  19. if len(links) == 2:
  20. (arg1, pos1), (arg2, pos2) = links
  21. dlinks[arg1][pos1] = (arg2, pos2)
  22. dlinks[arg2][pos2] = (arg1, pos1)
  23. continue
  24. return args, dict(dlinks)
  25. def _sort_contraction_indices(pairing_indices):
  26. pairing_indices = [Tuple(*sorted(i)) for i in pairing_indices]
  27. pairing_indices.sort(key=lambda x: min(x))
  28. return pairing_indices
  29. def _get_diagonal_indices(flattened_indices):
  30. axes_contraction = defaultdict(list)
  31. for i, ind in enumerate(flattened_indices):
  32. if isinstance(ind, (int, Integer)):
  33. # If the indices is a number, there can be no diagonal operation:
  34. continue
  35. axes_contraction[ind].append(i)
  36. axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1}
  37. # Put the diagonalized indices at the end:
  38. ret_indices = [i for i in flattened_indices if i not in axes_contraction]
  39. diag_indices = list(axes_contraction)
  40. diag_indices.sort(key=lambda x: flattened_indices.index(x))
  41. diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices]
  42. ret_indices += diag_indices
  43. ret_indices = tuple(ret_indices)
  44. return diagonal_indices, ret_indices
  45. def _get_argindex(subindices, ind):
  46. for i, sind in enumerate(subindices):
  47. if ind == sind:
  48. return i
  49. if isinstance(sind, (set, frozenset)) and ind in sind:
  50. return i
  51. raise IndexError("%s not found in %s" % (ind, subindices))
  52. def _apply_recursively_over_nested_lists(func, arr):
  53. if isinstance(arr, (tuple, list, Tuple)):
  54. return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
  55. elif isinstance(arr, Tuple):
  56. return Tuple.fromiter(_apply_recursively_over_nested_lists(func, i) for i in arr)
  57. else:
  58. return func(arr)
  59. def _build_push_indices_up_func_transformation(flattened_contraction_indices):
  60. shifts = {0: 0}
  61. i = 0
  62. cumulative = 0
  63. while i < len(flattened_contraction_indices):
  64. j = 1
  65. while i+j < len(flattened_contraction_indices):
  66. if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
  67. break
  68. j += 1
  69. cumulative += j
  70. shifts[flattened_contraction_indices[i]] = cumulative
  71. i += j
  72. shift_keys = sorted(shifts.keys())
  73. def func(idx):
  74. return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
  75. def transform(j):
  76. if j in flattened_contraction_indices:
  77. return None
  78. else:
  79. return j - func(j)
  80. return transform
  81. def _build_push_indices_down_func_transformation(flattened_contraction_indices):
  82. N = flattened_contraction_indices[-1]+2
  83. shifts = [i for i in range(N) if i not in flattened_contraction_indices]
  84. def transform(j):
  85. if j < len(shifts):
  86. return shifts[j]
  87. else:
  88. return j + shifts[-1] - len(shifts) + 1
  89. return transform
  90. def _apply_permutation_to_list(perm: Permutation, target_list: list):
  91. """
  92. Permute a list according to the given permutation.
  93. """
  94. new_list = [None for i in range(perm.size)]
  95. for i, e in enumerate(target_list):
  96. new_list[perm(i)] = e
  97. return new_list