12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443 |
- """
- Implementation of optimized einsum.
- """
- import itertools
- import operator
- from numpy.core.multiarray import c_einsum
- from numpy.core.numeric import asanyarray, tensordot
- from numpy.core.overrides import array_function_dispatch
# Public names exported by ``from <this module> import *``.
__all__ = ['einsum', 'einsum_path']
# The 52 single-character labels accepted as einsum subscripts: the ASCII
# lowercase letters followed by the ASCII uppercase letters. Integer labels in
# the sublist calling form index into this string.
einsum_symbols = 'abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
# Set view of the same labels, used for set arithmetic (e.g. finding labels
# not yet used by a subscript string).
einsum_symbols_set = {*einsum_symbols}
def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
    """
    Estimate the number of FLOPS needed for a single contraction.

    The estimate is the size of the full index space touched by the
    contraction times an operation factor. The factor is one per operand
    beyond the first (never less than one), plus one more when ``inner``
    is truthy.

    Parameters
    ----------
    idx_contraction : iterable
        Every index taking part in the contraction.
    inner : bool
        Whether the contraction sums over at least one index. Only its
        truthiness is used, so callers may pass e.g. the set of removed
        indices.
    num_terms : int
        Number of operands participating in the contraction.
    size_dictionary : dict
        Dimension size of each index in ``idx_contraction``.

    Returns
    -------
    flop_count : int
        Estimated FLOPS for the contraction.

    Examples
    --------
    >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
    30
    >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
    60
    """
    space_size = _compute_size_by_dict(idx_contraction, size_dictionary)

    extra_terms = num_terms - 1
    # One operation per extra operand, but never fewer than one in total.
    factor = extra_terms if extra_terms > 1 else 1
    if inner:
        # Contractions that sum over an index are charged one more operation.
        factor = factor + 1

    return space_size * factor
- def _compute_size_by_dict(indices, idx_dict):
- """
- Computes the product of the elements in indices based on the dictionary
- idx_dict.
- Parameters
- ----------
- indices : iterable
- Indices to base the product on.
- idx_dict : dictionary
- Dictionary of index sizes
- Returns
- -------
- ret : int
- The resulting product.
- Examples
- --------
- >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
- 90
- """
- ret = 1
- for i in indices:
- ret *= idx_dict[i]
- return ret
- def _find_contraction(positions, input_sets, output_set):
- """
- Finds the contraction for a given set of input and output sets.
- Parameters
- ----------
- positions : iterable
- Integer positions of terms used in the contraction.
- input_sets : list
- List of sets that represent the lhs side of the einsum subscript
- output_set : set
- Set that represents the rhs side of the overall einsum subscript
- Returns
- -------
- new_result : set
- The indices of the resulting contraction
- remaining : list
- List of sets that have not been contracted, the new set is appended to
- the end of this list
- idx_removed : set
- Indices removed from the entire contraction
- idx_contraction : set
- The indices used in the current contraction
- Examples
- --------
- # A simple dot product test case
- >>> pos = (0, 1)
- >>> isets = [set('ab'), set('bc')]
- >>> oset = set('ac')
- >>> _find_contraction(pos, isets, oset)
- ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
- # A more complex case with additional terms in the contraction
- >>> pos = (0, 2)
- >>> isets = [set('abd'), set('ac'), set('bdc')]
- >>> oset = set('ac')
- >>> _find_contraction(pos, isets, oset)
- ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
- """
- idx_contract = set()
- idx_remain = output_set.copy()
- remaining = []
- for ind, value in enumerate(input_sets):
- if ind in positions:
- idx_contract |= value
- else:
- remaining.append(value)
- idx_remain |= value
- new_result = idx_remain & idx_contract
- idx_removed = (idx_contract - new_result)
- remaining.append(new_result)
- return (new_result, remaining, idx_removed, idx_contract)
def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Exhaustively search pairwise contraction orders for the cheapest one.

    Every possible pair contraction is explored at every step, discarding
    any whose intermediate would exceed ``memory_limit`` elements. The
    lowest accumulated-cost path is returned. The search scales factorially
    with the number of elements in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The optimal contraction order within the memory limit constraint.
        If at some step no pair fits in memory, the cheapest partial path
        is returned, followed by a single contraction of all remaining terms.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _optimal_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    n_terms = len(input_sets)
    # Each candidate is (accumulated_cost, path_so_far, remaining_index_sets).
    candidates = [(0, [], input_sets)]

    for step in range(n_terms - 1):
        n_left = n_terms - step
        next_candidates = []
        for cost, path_so_far, remaining in candidates:
            for pair in itertools.combinations(range(n_left), 2):
                new_result, new_sets, idx_removed, idx_contract = \
                    _find_contraction(pair, remaining, output_set)

                # Sieve the results based on memory_limit
                if _compute_size_by_dict(new_result, idx_dict) > memory_limit:
                    continue

                step_cost = _flop_count(idx_contract, idx_removed, len(pair),
                                        idx_dict)
                next_candidates.append(
                    (cost + step_cost, path_so_far + [pair], new_sets))

        if not next_candidates:
            # Nothing fits in memory: finish the cheapest partial path with
            # one contraction over everything that is left.
            best_path = min(candidates, key=lambda cand: cand[0])[1]
            best_path += [tuple(range(n_left))]
            return best_path

        candidates = next_candidates

    # ``candidates`` is never empty here: it starts with one entry and is
    # only ever replaced by a non-empty list.
    return min(candidates, key=lambda cand: cand[0])[1]
def _parse_possible_contraction(positions, input_sets, output_set, idx_dict, memory_limit, path_cost, naive_cost):
    """
    Score the contraction of the tensors at ``positions``, or reject it.

    A candidate is rejected (``None`` is returned) when its intermediate
    exceeds ``memory_limit`` elements, or when doing it would push the
    running cost above the naive single-einsum cost.

    Parameters
    ----------
    positions : tuple of int
        The locations of the proposed tensors to contract.
    input_sets : list of sets
        The indices found on each tensors.
    output_set : set
        The output indices of the expression.
    idx_dict : dict
        Mapping of each index to its size.
    memory_limit : int
        The total allowed size for an intermediary tensor.
    path_cost : int
        The contraction cost so far.
    naive_cost : int
        The cost of the unoptimized expression.

    Returns
    -------
    None or list
        ``None`` if rejected. Otherwise ``[sort, positions, new_input_sets]``:

        - ``sort``: the key ``(-removed_size, flop_cost)``. ``removed_size``
          is the total element count of the consumed operands minus that of
          the new intermediate. Smaller keys are preferred.
        - ``positions``: the positions passed in.
        - ``new_input_sets``: the list of index sets after the contraction.
    """
    idx_result, new_input_sets, idx_removed, idx_contract = \
        _find_contraction(positions, input_sets, output_set)

    # Sieve on memory: the new intermediate must fit.
    new_size = _compute_size_by_dict(idx_result, idx_dict)
    if new_size > memory_limit:
        return None

    # How many elements disappear. Historically this was only the size of
    # the removed indices (compute_size_by_dict(idx_removed, idx_dict)).
    removed_size = sum(_compute_size_by_dict(input_sets[p], idx_dict)
                       for p in positions) - new_size
    flop_cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)

    # Sieve on total cost: never do worse than the naive contraction.
    if path_cost + flop_cost > naive_cost:
        return None

    return [(-removed_size, flop_cost), positions, new_input_sets]
- def _update_other_results(results, best):
- """Update the positions and provisional input_sets of ``results`` based on
- performing the contraction result ``best``. Remove any involving the tensors
- contracted.
- Parameters
- ----------
- results : list
- List of contraction results produced by ``_parse_possible_contraction``.
- best : list
- The best contraction of ``results`` i.e. the one that will be performed.
- Returns
- -------
- mod_results : list
- The list of modified results, updated with outcome of ``best`` contraction.
- """
- best_con = best[1]
- bx, by = best_con
- mod_results = []
- for cost, (x, y), con_sets in results:
- # Ignore results involving tensors just contracted
- if x in best_con or y in best_con:
- continue
- # Update the input_sets
- del con_sets[by - int(by > x) - int(by > y)]
- del con_sets[bx - int(bx > x) - int(bx > y)]
- con_sets.insert(-1, best[2][-1])
- # Update the position indices
- mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
- mod_results.append((cost, mod_con, con_sets))
- return mod_results
def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
    """
    Build a path by repeatedly performing the best-scoring pair contraction.

    Each step picks the pair minimising ``(-removed_size, cost)`` (see
    ``_parse_possible_contraction``). Pairs that share an index are
    considered first. Outer products (disjoint pairs) are only scanned when
    no sharing pair is acceptable, and are still bounded by
    ``memory_limit``. The first step scores every pair; later steps only
    score pairs involving the newly created tensor and carry the other
    scores forward. The algorithm scales cubically with the number of
    elements in ``input_sets``.

    Parameters
    ----------
    input_sets : list
        List of sets that represent the lhs side of the einsum subscript
    output_set : set
        Set that represents the rhs side of the overall einsum subscript
    idx_dict : dictionary
        Dictionary of index sizes
    memory_limit : int
        The maximum number of elements in a temporary array

    Returns
    -------
    path : list
        The greedy contraction order within the memory limit constraint.
        If at some step no pair is acceptable, a single contraction of all
        remaining terms ends the path.

    Examples
    --------
    >>> isets = [set('abd'), set('ac'), set('bdc')]
    >>> oset = set()
    >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
    >>> _greedy_path(isets, oset, idx_sizes, 5000)
    [(0, 2), (0, 1)]
    """
    n_terms = len(input_sets)

    # Handle trivial cases that leaked through
    if n_terms == 1:
        return [(0,)]
    if n_terms == 2:
        return [(0, 1)]

    # Cost of contracting everything in one einsum call. Candidates whose
    # running cost would exceed it are rejected.
    _, _, naive_removed, naive_contract = _find_contraction(
        range(n_terms), input_sets, output_set)
    naive_cost = _flop_count(naive_contract, naive_removed, n_terms, idx_dict)

    pairs_to_score = itertools.combinations(range(n_terms), 2)
    candidates = []
    path_cost = 0
    path = []

    for _ in range(n_terms - 1):
        # First pass: only pairs that share at least one index.
        for pair in pairs_to_score:
            left, right = pair
            if input_sets[left].isdisjoint(input_sets[right]):
                continue
            scored = _parse_possible_contraction(
                pair, input_sets, output_set, idx_dict, memory_limit,
                path_cost, naive_cost)
            if scored is not None:
                candidates.append(scored)

        # Nothing usable: rescan every pair, outer products included.
        if not candidates:
            for pair in itertools.combinations(range(len(input_sets)), 2):
                scored = _parse_possible_contraction(
                    pair, input_sets, output_set, idx_dict, memory_limit,
                    path_cost, naive_cost)
                if scored is not None:
                    candidates.append(scored)

        # Still nothing: fall back to one einsum over everything left.
        if not candidates:
            path.append(tuple(range(len(input_sets))))
            break

        best = min(candidates, key=lambda entry: entry[0])
        # Carry the untouched candidates forward, re-indexed for the new list.
        candidates = _update_other_results(candidates, best)

        # Only pairs involving the new tensor (appended last) need scoring
        # next time round; everything else is already in ``candidates``.
        input_sets = best[2]
        newest = len(input_sets) - 1
        pairs_to_score = [(i, newest) for i in range(newest)]

        path.append(best[1])
        path_cost += best[0][1]

    return path
- def _can_dot(inputs, result, idx_removed):
- """
- Checks if we can use BLAS (np.tensordot) call and its beneficial to do so.
- Parameters
- ----------
- inputs : list of str
- Specifies the subscripts for summation.
- result : str
- Resulting summation.
- idx_removed : set
- Indices that are removed in the summation
- Returns
- -------
- type : bool
- Returns true if BLAS should and can be used, else False
- Notes
- -----
- If the operations is BLAS level 1 or 2 and is not already aligned
- we default back to einsum as the memory movement to copy is more
- costly than the operation itself.
- Examples
- --------
- # Standard GEMM operation
- >>> _can_dot(['ij', 'jk'], 'ik', set('j'))
- True
- # Can use the standard BLAS, but requires odd data movement
- >>> _can_dot(['ijj', 'jk'], 'ik', set('j'))
- False
- # DDOT where the memory is not aligned
- >>> _can_dot(['ijk', 'ikj'], '', set('ijk'))
- False
- """
- # All `dot` calls remove indices
- if len(idx_removed) == 0:
- return False
- # BLAS can only handle two operands
- if len(inputs) != 2:
- return False
- input_left, input_right = inputs
- for c in set(input_left + input_right):
- # can't deal with repeated indices on same input or more than 2 total
- nl, nr = input_left.count(c), input_right.count(c)
- if (nl > 1) or (nr > 1) or (nl + nr > 2):
- return False
- # can't do implicit summation or dimension collapse e.g.
- # "ab,bc->c" (implicitly sum over 'a')
- # "ab,ca->ca" (take diagonal of 'a')
- if nl + nr - 1 == int(c in result):
- return False
- # Build a few temporaries
- set_left = set(input_left)
- set_right = set(input_right)
- keep_left = set_left - idx_removed
- keep_right = set_right - idx_removed
- rs = len(idx_removed)
- # At this point we are a DOT, GEMV, or GEMM operation
- # Handle inner products
- # DDOT with aligned data
- if input_left == input_right:
- return True
- # DDOT without aligned data (better to use einsum)
- if set_left == set_right:
- return False
- # Handle the 4 possible (aligned) GEMV or GEMM cases
- # GEMM or GEMV no transpose
- if input_left[-rs:] == input_right[:rs]:
- return True
- # GEMM or GEMV transpose both
- if input_left[:rs] == input_right[-rs:]:
- return True
- # GEMM or GEMV transpose right
- if input_left[-rs:] == input_right[-rs:]:
- return True
- # GEMM or GEMV transpose left
- if input_left[:rs] == input_right[:rs]:
- return True
- # Einsum is faster than GEMV if we have to copy data
- if not keep_left or not keep_right:
- return False
- # We are a matrix-matrix product, but we need to copy data
- return True
def _sublist_to_subscripts(sublist):
    """
    Convert one einsum sublist (ints and ``Ellipsis``) to a subscript string.

    Parameters
    ----------
    sublist : iterable
        Integer labels in ``[0, len(einsum_symbols))`` and/or ``Ellipsis``.

    Returns
    -------
    subscripts : str
        Each integer mapped through ``einsum_symbols`` and each ``Ellipsis``
        written as ``"..."``.

    Raises
    ------
    TypeError
        If an element is neither ``Ellipsis`` nor usable as an integer index.
    ValueError
        If an integer label lies outside ``[0, len(einsum_symbols))``.
    """
    chars = []
    for s in sublist:
        if s is Ellipsis:
            chars.append("...")
            continue
        try:
            s = operator.index(s)
        except TypeError as e:
            raise TypeError("For this input type lists must contain "
                            "either int or Ellipsis") from e
        # Without this check a negative label silently wraps around
        # (-1 -> 'Z') and a too-large one raises a bare IndexError.
        if not 0 <= s < len(einsum_symbols):
            raise ValueError("subscript is not within the valid range "
                             "[0, %d)" % len(einsum_symbols))
        chars.append(einsum_symbols[s])
    return "".join(chars)


def _implicit_output_subscript(subscripts):
    """
    Build the implicit-mode output for ``subscripts`` (which has no ``->``).

    The output is every label that appears exactly once, in sorted order.

    Raises
    ------
    ValueError
        If a label is not in ``einsum_symbols``.
    """
    tmp_subscripts = subscripts.replace(",", "")
    output_subscript = ""
    for s in sorted(set(tmp_subscripts)):
        if s not in einsum_symbols:
            raise ValueError("Character %s is not a valid symbol." % s)
        if tmp_subscripts.count(s) == 1:
            output_subscript += s
    return output_subscript


def _parse_einsum_input(operands):
    """
    A reproduction of einsum c side einsum parsing in python.

    Two calling forms are accepted: ``(subscripts, *arrays)`` and the
    interleaved sublist form ``(array0, sublist0, array1, sublist1, ...,
    [sublistout])``.

    Parameters
    ----------
    operands : sequence
        The positional arguments passed to ``einsum``/``einsum_path``.

    Returns
    -------
    input_strings : str
        Parsed input strings
    output_string : str
        Parsed output string
    operands : list of array_like
        The operands to use in the numpy contraction

    Raises
    ------
    ValueError
        Raised when:

        - there are no operands;
        - a symbol or integer label is invalid;
        - ``->`` or an ellipsis is malformed;
        - ellipsis lengths do not match;
        - an output label is missing from the inputs;
        - the number of terms differs from the number of operands.
    TypeError
        If a sublist contains something other than ints or ``Ellipsis``.

    Examples
    --------
    The operand list is simplified to reduce printing:

    >>> np.random.seed(123)
    >>> a = np.random.rand(4, 4)
    >>> b = np.random.rand(4, 4, 4)
    >>> _parse_einsum_input(('...a,...a->...', a, b))
    ('za,xza', 'xz', [a, b]) # may vary

    >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
    ('za,xza', 'xz', [a, b]) # may vary
    """
    if len(operands) == 0:
        raise ValueError("No input operands")

    if isinstance(operands[0], str):
        subscripts = operands[0].replace(" ", "")
        operands = [asanyarray(v) for v in operands[1:]]

        # Ensure all characters are valid
        for s in subscripts:
            if s in '.,->':
                continue
            if s not in einsum_symbols:
                raise ValueError("Character %s is not a valid symbol." % s)

    else:
        # Interleaved form: (op0, sub0, op1, sub1, ..., [sub_out]). Slicing
        # avoids the quadratic cost of repeated list.pop(0).
        items = list(operands)
        n_pairs = len(items) // 2
        operand_list = items[0:2 * n_pairs:2]
        subscript_list = items[1:2 * n_pairs:2]
        # A trailing unpaired item is the output sublist.
        output_list = items[-1] if len(items) % 2 else None
        operands = [asanyarray(v) for v in operand_list]

        subscripts = ",".join(_sublist_to_subscripts(sub)
                              for sub in subscript_list)
        if output_list is not None:
            subscripts += "->" + _sublist_to_subscripts(output_list)

    # Check for proper "->"
    if ("-" in subscripts) or (">" in subscripts):
        invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
        if invalid or (subscripts.count("->") != 1):
            raise ValueError("Subscripts can only contain one '->'.")

    # Parse ellipses
    if "." in subscripts:
        used = subscripts.replace(".", "").replace(",", "").replace("->", "")
        # Labels not used explicitly are free to stand in for ellipsis dims.
        unused = list(einsum_symbols_set - set(used))
        ellipse_inds = "".join(unused)
        longest = 0

        if "->" in subscripts:
            input_tmp, output_sub = subscripts.split("->")
            split_subscripts = input_tmp.split(",")
            out_sub = True
        else:
            split_subscripts = subscripts.split(',')
            out_sub = False

        # Check before indexing ``operands[num]`` below; otherwise a surplus
        # term would raise a bare IndexError instead of this error.
        if len(split_subscripts) != len(operands):
            raise ValueError("Number of einsum subscripts must be equal to the "
                             "number of operands.")

        for num, sub in enumerate(split_subscripts):
            if "." in sub:
                if (sub.count(".") != 3) or (sub.count("...") != 1):
                    raise ValueError("Invalid Ellipses.")

                # Take into account numerical values
                if operands[num].shape == ():
                    ellipse_count = 0
                else:
                    ellipse_count = max(operands[num].ndim, 1)
                    ellipse_count -= (len(sub) - 3)

                if ellipse_count > longest:
                    longest = ellipse_count

                if ellipse_count < 0:
                    raise ValueError("Ellipses lengths do not match.")
                elif ellipse_count == 0:
                    split_subscripts[num] = sub.replace('...', '')
                else:
                    # Right-align the ellipsis labels so shorter ellipses
                    # broadcast against longer ones.
                    rep_inds = ellipse_inds[-ellipse_count:]
                    split_subscripts[num] = sub.replace('...', rep_inds)

        subscripts = ",".join(split_subscripts)
        if longest == 0:
            out_ellipse = ""
        else:
            out_ellipse = ellipse_inds[-longest:]

        if out_sub:
            subscripts += "->" + output_sub.replace("...", out_ellipse)
        else:
            # Special care for outputless ellipses: broadcast labels come
            # first, then the remaining single-occurrence labels, sorted.
            output_subscript = _implicit_output_subscript(subscripts)
            normal_inds = ''.join(sorted(set(output_subscript) -
                                         set(out_ellipse)))

            subscripts += "->" + out_ellipse + normal_inds

    # Build output string if does not exist
    if "->" in subscripts:
        input_subscripts, output_subscript = subscripts.split("->")
    else:
        input_subscripts = subscripts
        output_subscript = _implicit_output_subscript(subscripts)

    # Make sure output subscripts are in the input
    for char in output_subscript:
        if char not in input_subscripts:
            raise ValueError("Output character %s did not appear in the input"
                             % char)

    # Make sure number operands is equivalent to the number of terms
    if len(input_subscripts.split(',')) != len(operands):
        raise ValueError("Number of einsum subscripts must be equal to the "
                         "number of operands.")

    return (input_subscripts, output_subscript, operands)
- def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
- # NOTE: technically, we should only dispatch on array-like arguments, not
- # subscripts (given as strings). But separating operands into
- # arrays/subscripts is a little tricky/slow (given einsum's two supported
- # signatures), so as a practical shortcut we dispatch on everything.
- # Strings will be ignored for dispatching since they don't define
- # __array_function__.
- return operands
- @array_function_dispatch(_einsum_path_dispatcher, module='numpy')
- def einsum_path(*operands, optimize='greedy', einsum_call=False):
- """
- einsum_path(subscripts, *operands, optimize='greedy')
- Evaluates the lowest cost contraction order for an einsum expression by
- considering the creation of intermediate arrays.
- Parameters
- ----------
- subscripts : str
- Specifies the subscripts for summation.
- *operands : list of array_like
- These are the arrays for the operation.
- optimize : {bool, list, tuple, 'greedy', 'optimal'}
- Choose the type of path. If a tuple is provided, the second argument is
- assumed to be the maximum intermediate size created. If only a single
- argument is provided the largest input or output array size is used
- as a maximum intermediate size.
- * if a list is given that starts with ``einsum_path``, uses this as the
- contraction path
- * if False no optimization is taken
- * if True defaults to the 'greedy' algorithm
- * 'optimal' An algorithm that combinatorially explores all possible
- ways of contracting the listed tensors and choosest the least costly
- path. Scales exponentially with the number of terms in the
- contraction.
- * 'greedy' An algorithm that chooses the best pair contraction
- at each step. Effectively, this algorithm searches the largest inner,
- Hadamard, and then outer products at each step. Scales cubically with
- the number of terms in the contraction. Equivalent to the 'optimal'
- path for most contractions.
- Default is 'greedy'.
- Returns
- -------
- path : list of tuples
- A list representation of the einsum path.
- string_repr : str
- A printable representation of the einsum path.
- Notes
- -----
- The resulting path indicates which terms of the input contraction should be
- contracted first, the result of this contraction is then appended to the
- end of the contraction list. This list can then be iterated over until all
- intermediate contractions are complete.
- See Also
- --------
- einsum, linalg.multi_dot
- Examples
- --------
- We can begin with a chain dot example. In this case, it is optimal to
- contract the ``b`` and ``c`` tensors first as represented by the first
- element of the path ``(1, 2)``. The resulting tensor is added to the end
- of the contraction and the remaining contraction ``(0, 1)`` is then
- completed.
- >>> np.random.seed(123)
- >>> a = np.random.rand(2, 2)
- >>> b = np.random.rand(2, 5)
- >>> c = np.random.rand(5, 2)
- >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
- >>> print(path_info[0])
- ['einsum_path', (1, 2), (0, 1)]
- >>> print(path_info[1])
- Complete contraction: ij,jk,kl->il # may vary
- Naive scaling: 4
- Optimized scaling: 3
- Naive FLOP count: 1.600e+02
- Optimized FLOP count: 5.600e+01
- Theoretical speedup: 2.857
- Largest intermediate: 4.000e+00 elements
- -------------------------------------------------------------------------
- scaling current remaining
- -------------------------------------------------------------------------
- 3 kl,jk->jl ij,jl->il
- 3 jl,ij->il il->il
- A more complex index transformation example.
- >>> I = np.random.rand(10, 10, 10, 10)
- >>> C = np.random.rand(10, 10)
- >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
- ... optimize='greedy')
- >>> print(path_info[0])
- ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
- >>> print(path_info[1])
- Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
- Naive scaling: 8
- Optimized scaling: 5
- Naive FLOP count: 8.000e+08
- Optimized FLOP count: 8.000e+05
- Theoretical speedup: 1000.000
- Largest intermediate: 1.000e+04 elements
- --------------------------------------------------------------------------
- scaling current remaining
- --------------------------------------------------------------------------
- 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
- 5 bcde,fb->cdef gc,hd,cdef->efgh
- 5 cdef,gc->defg hd,defg->efgh
- 5 defg,hd->efgh efgh->efgh
- """
- # Figure out what the path really is
- path_type = optimize
- if path_type is True:
- path_type = 'greedy'
- if path_type is None:
- path_type = False
- explicit_einsum_path = False
- memory_limit = None
- # No optimization or a named path algorithm
- if (path_type is False) or isinstance(path_type, str):
- pass
- # Given an explicit path
- elif len(path_type) and (path_type[0] == 'einsum_path'):
- explicit_einsum_path = True
- # Path tuple with memory limit
- elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
- isinstance(path_type[1], (int, float))):
- memory_limit = int(path_type[1])
- path_type = path_type[0]
- else:
- raise TypeError("Did not understand the path: %s" % str(path_type))
- # Hidden option, only einsum should call this
- einsum_call_arg = einsum_call
- # Python side parsing
- input_subscripts, output_subscript, operands = _parse_einsum_input(operands)
- # Build a few useful list and sets
- input_list = input_subscripts.split(',')
- input_sets = [set(x) for x in input_list]
- output_set = set(output_subscript)
- indices = set(input_subscripts.replace(',', ''))
- # Get length of each unique dimension and ensure all dimensions are correct
- dimension_dict = {}
- broadcast_indices = [[] for x in range(len(input_list))]
- for tnum, term in enumerate(input_list):
- sh = operands[tnum].shape
- if len(sh) != len(term):
- raise ValueError("Einstein sum subscript %s does not contain the "
- "correct number of indices for operand %d."
- % (input_subscripts[tnum], tnum))
- for cnum, char in enumerate(term):
- dim = sh[cnum]
- # Build out broadcast indices
- if dim == 1:
- broadcast_indices[tnum].append(char)
- if char in dimension_dict.keys():
- # For broadcasting cases we always want the largest dim size
- if dimension_dict[char] == 1:
- dimension_dict[char] = dim
- elif dim not in (1, dimension_dict[char]):
- raise ValueError("Size of label '%s' for operand %d (%d) "
- "does not match previous terms (%d)."
- % (char, tnum, dimension_dict[char], dim))
- else:
- dimension_dict[char] = dim
- # Convert broadcast inds to sets
- broadcast_indices = [set(x) for x in broadcast_indices]
- # Compute size of each input array plus the output array
- size_list = [_compute_size_by_dict(term, dimension_dict)
- for term in input_list + [output_subscript]]
- max_size = max(size_list)
- if memory_limit is None:
- memory_arg = max_size
- else:
- memory_arg = memory_limit
- # Compute naive cost
- # This isn't quite right, need to look into exactly how einsum does this
- inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
- naive_cost = _flop_count(indices, inner_product, len(input_list), dimension_dict)
- # Compute the path
- if explicit_einsum_path:
- path = path_type[1:]
- elif (
- (path_type is False)
- or (len(input_list) in [1, 2])
- or (indices == output_set)
- ):
- # Nothing to be optimized, leave it to einsum
- path = [tuple(range(len(input_list)))]
- elif path_type == "greedy":
- path = _greedy_path(input_sets, output_set, dimension_dict, memory_arg)
- elif path_type == "optimal":
- path = _optimal_path(input_sets, output_set, dimension_dict, memory_arg)
- else:
- raise KeyError("Path name %s not found", path_type)
- cost_list, scale_list, size_list, contraction_list = [], [], [], []
- # Build contraction tuple (positions, gemm, einsum_str, remaining)
- for cnum, contract_inds in enumerate(path):
- # Make sure we remove inds from right to left
- contract_inds = tuple(sorted(list(contract_inds), reverse=True))
- contract = _find_contraction(contract_inds, input_sets, output_set)
- out_inds, input_sets, idx_removed, idx_contract = contract
- cost = _flop_count(idx_contract, idx_removed, len(contract_inds), dimension_dict)
- cost_list.append(cost)
- scale_list.append(len(idx_contract))
- size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
- bcast = set()
- tmp_inputs = []
- for x in contract_inds:
- tmp_inputs.append(input_list.pop(x))
- bcast |= broadcast_indices.pop(x)
- new_bcast_inds = bcast - idx_removed
- # If we're broadcasting, nix blas
- if not len(idx_removed & bcast):
- do_blas = _can_dot(tmp_inputs, out_inds, idx_removed)
- else:
- do_blas = False
- # Last contraction
- if (cnum - len(path)) == -1:
- idx_result = output_subscript
- else:
- sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
- idx_result = "".join([x[1] for x in sorted(sort_result)])
- input_list.append(idx_result)
- broadcast_indices.append(new_bcast_inds)
- einsum_str = ",".join(tmp_inputs) + "->" + idx_result
- contraction = (contract_inds, idx_removed, einsum_str, input_list[:], do_blas)
- contraction_list.append(contraction)
- opt_cost = sum(cost_list) + 1
- if len(input_list) != 1:
- # Explicit "einsum_path" is usually trusted, but we detect this kind of
- # mistake in order to prevent from returning an intermediate value.
- raise RuntimeError(
- "Invalid einsum_path is specified: {} more operands has to be "
- "contracted.".format(len(input_list) - 1))
- if einsum_call_arg:
- return (operands, contraction_list)
- # Return the path along with a nice string representation
- overall_contraction = input_subscripts + "->" + output_subscript
- header = ("scaling", "current", "remaining")
- speedup = naive_cost / opt_cost
- max_i = max(size_list)
- path_print = " Complete contraction: %s\n" % overall_contraction
- path_print += " Naive scaling: %d\n" % len(indices)
- path_print += " Optimized scaling: %d\n" % max(scale_list)
- path_print += " Naive FLOP count: %.3e\n" % naive_cost
- path_print += " Optimized FLOP count: %.3e\n" % opt_cost
- path_print += " Theoretical speedup: %3.3f\n" % speedup
- path_print += " Largest intermediate: %.3e elements\n" % max_i
- path_print += "-" * 74 + "\n"
- path_print += "%6s %24s %40s\n" % header
- path_print += "-" * 74
- for n, contraction in enumerate(contraction_list):
- inds, idx_rm, einsum_str, remaining, blas = contraction
- remaining_str = ",".join(remaining) + "->" + output_subscript
- path_run = (scale_list[n], einsum_str, remaining_str)
- path_print += "\n%4d %24s %40s" % path_run
- path = ['einsum_path'] + path
- return (path, path_print)
- def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
- # Arguably we dispatch on more arguments than we really should; see note in
- # _einsum_path_dispatcher for why.
- yield from operands
- yield out
# Python-level wrapper around ``c_einsum`` that adds optional
# contraction-order optimization (the ``optimize`` argument).
@array_function_dispatch(_einsum_dispatcher, module='numpy')
def einsum(*operands, out=None, optimize=False, **kwargs):
    """
    einsum(subscripts, *operands, out=None, dtype=None, order='K',
           casting='safe', optimize=False)

    Evaluates the Einstein summation convention on the operands.

    Using the Einstein summation convention, many common multi-dimensional,
    linear algebraic array operations can be represented in a simple fashion.
    In *implicit* mode, `einsum` computes these values following the
    classical convention: repeated labels are summed and the output axes
    are ordered alphabetically.

    In *explicit* mode, `einsum` provides further flexibility to compute
    other array operations that might not be considered classical Einstein
    summation operations, by disabling, or forcing summation over specified
    subscript labels.

    See the notes and examples for clarification.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation as comma separated list of
        subscript labels. An implicit (classical Einstein summation)
        calculation is performed unless the explicit indicator '->' is
        included as well as subscript labels of the precise output form.
    operands : list of array_like
        These are the arrays for the operation.
    out : ndarray, optional
        If provided, the calculation is done into this array.
    dtype : {data-type, None}, optional
        If provided, forces the calculation to use the data type specified.
        Note that you may have to also give a more liberal `casting`
        parameter to allow the conversions. Default is None.
    order : {'C', 'F', 'A', 'K'}, optional
        Controls the memory layout of the output. 'C' means it should
        be C contiguous. 'F' means it should be Fortran contiguous,
        'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
        'K' means it should be as close to the layout as the inputs as
        is possible, including arbitrarily permuted axes.
        Default is 'K'.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur.  Setting this to
        'unsafe' is not recommended, as it can adversely affect accumulations.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.

        Default is 'safe'.
    optimize : {False, True, 'greedy', 'optimal'}, optional
        Controls if intermediate optimization should occur. No optimization
        will occur if False and True will default to the 'greedy' algorithm.
        Also accepts an explicit contraction list from the ``np.einsum_path``
        function. See ``np.einsum_path`` for more details. Defaults to False.

    Returns
    -------
    output : ndarray
        The calculation based on the Einstein summation convention.

    See Also
    --------
    einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
    einops :
        a similar verbose interface is provided by the
        `einops <https://github.com/arogozhnikov/einops>`_ package to cover
        additional operations: transpose, reshape/flatten, repeat/tile,
        squeeze/unsqueeze and reductions.
    opt_einsum :
        `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
        optimizes contraction order for einsum-like expressions
        in a backend-agnostic manner.

    Notes
    -----
    .. versionadded:: 1.6.0

    The Einstein summation convention can be used to compute
    many multi-dimensional, linear algebraic array operations. `einsum`
    provides a succinct way of representing these.

    A non-exhaustive list of these operations,
    which can be computed by `einsum`, is shown below along with examples:

    * Trace of an array, :py:func:`numpy.trace`.
    * Return a diagonal, :py:func:`numpy.diag`.
    * Array axis summations, :py:func:`numpy.sum`.
    * Transpositions and permutations, :py:func:`numpy.transpose`.
    * Matrix multiplication and dot product, :py:func:`numpy.matmul`,
      :py:func:`numpy.dot`.
    * Vector inner and outer products, :py:func:`numpy.inner`,
      :py:func:`numpy.outer`.
    * Broadcasting, element-wise and scalar multiplication,
      :py:func:`numpy.multiply`.
    * Tensor contractions, :py:func:`numpy.tensordot`.
    * Chained array operations, in efficient calculation order,
      :py:func:`numpy.einsum_path`.

    The subscripts string is a comma-separated list of subscript labels,
    where each label refers to a dimension of the corresponding operand.
    Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
    is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
    appears only once, it is not summed, so ``np.einsum('i', a)`` produces a
    view of ``a`` with no changes. A further example ``np.einsum('ij,jk', a, b)``
    describes traditional matrix multiplication and is equivalent to
    :py:func:`np.matmul(a,b) <numpy.matmul>`. Repeated subscript labels in one
    operand take the diagonal. For example, ``np.einsum('ii', a)`` is equivalent
    to :py:func:`np.trace(a) <numpy.trace>`.

    In *implicit mode*, the chosen subscripts are important
    since the axes of the output are reordered alphabetically.  This
    means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
    ``np.einsum('ji', a)`` takes its transpose. Additionally,
    ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
    ``np.einsum('ij,jh', a, b)`` returns the transpose of the
    multiplication since subscript 'h' precedes subscript 'i'.

    In *explicit mode* the output can be directly controlled by
    specifying output subscript labels.  This requires the
    identifier '->' as well as the list of output subscript labels.
    This feature increases the flexibility of the function since
    summing can be disabled or forced when required. The call
    ``np.einsum('i->', a)`` is like :py:func:`np.sum(a, axis=-1) <numpy.sum>`,
    and ``np.einsum('ii->i', a)`` is like :py:func:`np.diag(a) <numpy.diag>`.
    The difference is that `einsum` does not allow broadcasting by default.
    Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
    order of the output subscript labels and therefore returns matrix
    multiplication, unlike the example above in implicit mode.

    To enable and control broadcasting, use an ellipsis.  Default
    NumPy-style broadcasting is done by adding an ellipsis
    to the left of each term, like ``np.einsum('...ii->...i', a)``.
    To take the trace along the first and last axes,
    you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
    product with the left-most indices instead of rightmost, one can do
    ``np.einsum('ij...,jk...->ik...', a, b)``.

    When there is only one operand, no axes are summed, and no output
    parameter is provided, a view into the operand is returned instead
    of a new array.  Thus, taking the diagonal as ``np.einsum('ii->i', a)``
    produces a view (changed in version 1.10.0).

    `einsum` also provides an alternative way to provide the subscripts
    and operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
    If the output shape is not provided in this format `einsum` will be
    calculated in implicit mode, otherwise it will be performed explicitly.
    The examples below have corresponding `einsum` calls with the two
    parameter methods.

    .. versionadded:: 1.10.0

    Views returned from einsum are now writeable whenever the input array
    is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
    have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
    and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
    of a 2D array.

    .. versionadded:: 1.12.0

    Added the ``optimize`` argument which will optimize the contraction order
    of an einsum expression. For a contraction with three or more operands this
    can greatly increase the computational efficiency at the cost of a larger
    memory footprint during computation.

    Typically a 'greedy' algorithm is applied which empirical tests have shown
    returns the optimal path in the majority of cases. In some cases 'optimal'
    will return the superlative path through a more expensive, exhaustive search.
    For iterative calculations it may be advisable to calculate the optimal path
    once and reuse that path by supplying it as an argument. An example is given
    below.

    See :py:func:`numpy.einsum_path` for more details.

    Examples
    --------
    >>> a = np.arange(25).reshape(5,5)
    >>> b = np.arange(5)
    >>> c = np.arange(6).reshape(2,3)

    Trace of a matrix:

    >>> np.einsum('ii', a)
    60
    >>> np.einsum(a, [0,0])
    60
    >>> np.trace(a)
    60

    Extract the diagonal (requires explicit form):

    >>> np.einsum('ii->i', a)
    array([ 0,  6, 12, 18, 24])
    >>> np.einsum(a, [0,0], [0])
    array([ 0,  6, 12, 18, 24])
    >>> np.diag(a)
    array([ 0,  6, 12, 18, 24])

    Sum over an axis (requires explicit form):

    >>> np.einsum('ij->i', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [0,1], [0])
    array([ 10,  35,  60,  85, 110])
    >>> np.sum(a, axis=1)
    array([ 10,  35,  60,  85, 110])

    For higher dimensional arrays summing a single axis can be done with ellipsis:

    >>> np.einsum('...j->...', a)
    array([ 10,  35,  60,  85, 110])
    >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
    array([ 10,  35,  60,  85, 110])

    Compute a matrix transpose, or reorder any number of axes:

    >>> np.einsum('ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum('ij->ji', c)
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.einsum(c, [1,0])
    array([[0, 3],
           [1, 4],
           [2, 5]])
    >>> np.transpose(c)
    array([[0, 3],
           [1, 4],
           [2, 5]])

    Vector inner products:

    >>> np.einsum('i,i', b, b)
    30
    >>> np.einsum(b, [0], b, [0])
    30
    >>> np.inner(b,b)
    30

    Matrix vector multiplication:

    >>> np.einsum('ij,j', a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum(a, [0,1], b, [1])
    array([ 30,  80, 130, 180, 230])
    >>> np.dot(a, b)
    array([ 30,  80, 130, 180, 230])
    >>> np.einsum('...j,j', a, b)
    array([ 30,  80, 130, 180, 230])

    Broadcasting and scalar multiplication:

    >>> np.einsum('..., ...', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(',ij', 3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
    array([[ 0,  3,  6],
           [ 9, 12, 15]])
    >>> np.multiply(3, c)
    array([[ 0,  3,  6],
           [ 9, 12, 15]])

    Vector outer product:

    >>> np.einsum('i,j', np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.einsum(np.arange(2)+1, [0], b, [1])
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])
    >>> np.outer(np.arange(2)+1, b)
    array([[0, 1, 2, 3, 4],
           [0, 2, 4, 6, 8]])

    Tensor contraction:

    >>> a = np.arange(60.).reshape(3,4,5)
    >>> b = np.arange(24.).reshape(4,3,2)
    >>> np.einsum('ijk,jil->kl', a, b)
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])
    >>> np.tensordot(a,b, axes=([1,0],[0,1]))
    array([[4400., 4730.],
           [4532., 4874.],
           [4664., 5018.],
           [4796., 5162.],
           [4928., 5306.]])

    Writeable returned arrays (since version 1.10.0):

    >>> a = np.zeros((3, 3))
    >>> np.einsum('ii->i', a)[:] = 1
    >>> a
    array([[1., 0., 0.],
           [0., 1., 0.],
           [0., 0., 1.]])

    Example of ellipsis use:

    >>> a = np.arange(6).reshape((3,2))
    >>> b = np.arange(12).reshape((4,3))
    >>> np.einsum('ki,jk->ij', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('ki,...k->i...', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])
    >>> np.einsum('k...,jk', a, b)
    array([[10, 28, 46, 64],
           [13, 40, 67, 94]])

    Chained array operations. For more complicated contractions, speed-ups
    might be achieved by repeatedly computing a 'greedy' path or pre-computing
    the 'optimal' path and repeatedly applying it, using an `einsum_path`
    insertion (since version 1.12.0). Performance improvements can be
    particularly significant with larger arrays:

    >>> a = np.ones(64).reshape(2,4,8)

    Basic `einsum`: ~1520ms  (benchmarked on 3.1GHz Intel i5.)

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)

    Sub-optimal `einsum` (due to repeated path calculation time): ~330ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')

    Greedy `einsum` (faster optimal path approximation): ~160ms

    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')

    Optimal `einsum` (best usage pattern in some use cases): ~110ms

    >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='optimal')[0]
    >>> for iteration in range(500):
    ...     _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)

    """
    # Special handling if out is specified: it is only passed to the final
    # contraction so that intermediates are never written into it.
    specified_out = out is not None

    # If no optimization, run pure einsum (c_einsum validates all kwargs).
    if optimize is False:
        if specified_out:
            kwargs['out'] = out
        return c_einsum(*operands, **kwargs)

    # Check the kwargs to avoid a more cryptic error later, without having to
    # repeat default values here.
    valid_einsum_kwargs = ['dtype', 'order', 'casting']
    unknown_kwargs = [k for k in kwargs if k not in valid_einsum_kwargs]
    if unknown_kwargs:
        raise TypeError("Did not understand the following kwargs: %s"
                        % unknown_kwargs)

    # Build the contraction list and operand
    operands, contraction_list = einsum_path(*operands, optimize=optimize,
                                             einsum_call=True)

    # Handle order kwarg for output array, c_einsum allows mixed case.
    # 'order' is removed from kwargs here so intermediate c_einsum calls do
    # not receive it; it is applied once to the final result instead.
    output_order = kwargs.pop('order', 'K')
    if output_order.upper() == 'A':
        if all(arr.flags.f_contiguous for arr in operands):
            output_order = 'F'
        else:
            output_order = 'C'

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, _, blas = contraction
        # einsum_path sorts each contraction's indices in descending order,
        # so popping them one after another never shifts an index that is
        # still waiting to be popped.
        tmp_operands = [operands.pop(x) for x in inds]

        # Only the last contraction writes into a user-supplied `out`.
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot if still possible
        if blas:
            # Checks have already been handled by einsum_path; presumably the
            # BLAS flag is only set for two-operand contractions, so the
            # input string splits into exactly a left and a right term.
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # tensordot lays out its result as the free axes of the left
            # operand followed by the free axes of the right operand.
            tensor_result = input_left + input_right
            for s in idx_rm:
                tensor_result = tensor_result.replace(s, "")

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in sorted(idx_rm):
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = tensordot(
                *tmp_operands, axes=(tuple(left_pos), tuple(right_pos))
            )

            # Reorder the axes with c_einsum if tensordot's layout differs
            # from the requested one, or if the result must land in `out`.
            if (tensor_result != results_index) or handle_out:
                if handle_out:
                    kwargs["out"] = out
                new_view = c_einsum(
                    tensor_result + '->' + results_index, new_view, **kwargs
                )

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                kwargs["out"] = out

            # Do the contraction
            new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    # When `out` was given, the final contraction already wrote into it.
    if specified_out:
        return out
    else:
        return asanyarray(operands[0], order=output_order)
|