123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123 |
- import bisect
- from collections import defaultdict
- from sympy.combinatorics import Permutation
- from sympy.core.containers import Tuple
- from sympy.core.numbers import Integer
- def _get_mapping_from_subranks(subranks):
- mapping = {}
- counter = 0
- for i, rank in enumerate(subranks):
- for j in range(rank):
- mapping[counter] = (i, j)
- counter += 1
- return mapping
def _get_contraction_links(args, subranks, *contraction_indices):
    """
    Describe pairwise contractions as links between argument positions.

    Every flattened index in ``contraction_indices`` is first translated into
    an ``(argument, local index)`` pair using ``subranks``. For each group made
    of exactly two indices, both directions of the link are recorded; groups
    of any other size are ignored.

    Returns ``(args, links)``, where ``args`` is passed through unchanged and
    ``links[arg][pos]`` is the ``(other_arg, other_pos)`` pair that position
    ``pos`` of argument ``arg`` is contracted with.
    """
    position_of = _get_mapping_from_subranks(subranks)
    # Translate every group up front so that an unknown index fails here,
    # whatever the size of its group.
    resolved_groups = [
        [position_of[flat] for flat in group] for group in contraction_indices
    ]

    links = {}
    for group in resolved_groups:
        if len(group) != 2:
            continue
        (arg_a, pos_a), (arg_b, pos_b) = group
        links.setdefault(arg_a, {})[pos_a] = (arg_b, pos_b)
        links.setdefault(arg_b, {})[pos_b] = (arg_a, pos_a)
    return args, links
def _sort_contraction_indices(pairing_indices):
    """
    Put groups of contraction indices into a canonical order.

    Each group is sorted and wrapped in a ``Tuple``; the groups themselves are
    then ordered by their smallest member. A new list is returned.
    """
    normalized = [Tuple(*sorted(group)) for group in pairing_indices]
    return sorted(normalized, key=min)
- def _get_diagonal_indices(flattened_indices):
- axes_contraction = defaultdict(list)
- for i, ind in enumerate(flattened_indices):
- if isinstance(ind, (int, Integer)):
- # If the indices is a number, there can be no diagonal operation:
- continue
- axes_contraction[ind].append(i)
- axes_contraction = {k: v for k, v in axes_contraction.items() if len(v) > 1}
- # Put the diagonalized indices at the end:
- ret_indices = [i for i in flattened_indices if i not in axes_contraction]
- diag_indices = list(axes_contraction)
- diag_indices.sort(key=lambda x: flattened_indices.index(x))
- diagonal_indices = [tuple(axes_contraction[i]) for i in diag_indices]
- ret_indices += diag_indices
- ret_indices = tuple(ret_indices)
- return diagonal_indices, ret_indices
- def _get_argindex(subindices, ind):
- for i, sind in enumerate(subindices):
- if ind == sind:
- return i
- if isinstance(sind, (set, frozenset)) and ind in sind:
- return i
- raise IndexError("%s not found in %s" % (ind, subindices))
def _apply_recursively_over_nested_lists(func, arr):
    """
    Apply ``func`` to every leaf of a nested sequence.

    Any ``tuple``, ``list`` or ``Tuple``, at any depth, is rebuilt as a plain
    ``tuple`` with the same nesting. Every other object is treated as a leaf
    and replaced by ``func(leaf)``.
    """
    if isinstance(arr, (tuple, list, Tuple)):
        return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
    # ``Tuple`` containers are covered by the check above and therefore come
    # back as plain tuples; whatever reaches this point is a leaf.
    return func(arr)
- def _build_push_indices_up_func_transformation(flattened_contraction_indices):
- shifts = {0: 0}
- i = 0
- cumulative = 0
- while i < len(flattened_contraction_indices):
- j = 1
- while i+j < len(flattened_contraction_indices):
- if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]:
- break
- j += 1
- cumulative += j
- shifts[flattened_contraction_indices[i]] = cumulative
- i += j
- shift_keys = sorted(shifts.keys())
- def func(idx):
- return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]]
- def transform(j):
- if j in flattened_contraction_indices:
- return None
- else:
- return j - func(j)
- return transform
- def _build_push_indices_down_func_transformation(flattened_contraction_indices):
- N = flattened_contraction_indices[-1]+2
- shifts = [i for i in range(N) if i not in flattened_contraction_indices]
- def transform(j):
- if j < len(shifts):
- return shifts[j]
- else:
- return j + shifts[-1] - len(shifts) + 1
- return transform
def _apply_permutation_to_list(perm: Permutation, target_list: list):
    """
    Return a new list with the items of ``target_list`` rearranged by ``perm``.

    The item at position ``i`` of ``target_list`` is stored at position
    ``perm(i)`` of the result. The result has ``perm.size`` slots; any slot
    that receives no item is left as ``None``.
    """
    permuted = [None] * perm.size
    for source, item in enumerate(target_list):
        permuted[perm(source)] = item
    return permuted
|