tensor.py 161 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863
  1. """
  2. This module defines tensors with abstract index notation.
  3. The abstract index notation was first formalized by Penrose.
  4. Tensor indices are formal objects, with a tensor type; there is no
  5. notion of index range, it is only possible to assign the dimension,
  6. used to trace the Kronecker delta; the dimension can be a Symbol.
  7. The Einstein summation convention is used.
  8. The covariant indices are indicated with a minus sign in front of the index.
  9. For instance the tensor ``t = p(a)*A(b,c)*q(-c)`` has the index ``c``
  10. contracted.
  11. A tensor expression ``t`` can be called; called with its
  12. indices in sorted order it is equal to itself:
  13. in the above example ``t(a, b) == t``;
  14. one can call ``t`` with different indices; ``t(c, d) == p(c)*A(d,a)*q(-a)``.
  15. The contracted indices are dummy indices, internally they have no name,
  16. the indices being represented by a graph-like structure.
  17. Tensors are put in canonical form using ``canon_bp``, which uses
  18. the Butler-Portugal algorithm for canonicalization using the monoterm
  19. symmetries of the tensors.
  20. If there is a (anti)symmetric metric, the indices can be raised and
  21. lowered when the tensor is put in canonical form.
  22. """
  23. from __future__ import annotations
  24. from typing import Any
  25. from functools import reduce
  26. from math import prod
  27. from abc import abstractmethod, ABC
  28. from collections import defaultdict
  29. import operator
  30. import itertools
  31. from sympy.core.numbers import (Integer, Rational)
  32. from sympy.combinatorics import Permutation
  33. from sympy.combinatorics.tensor_can import get_symmetric_group_sgs, \
  34. bsgs_direct_product, canonicalize, riemann_bsgs
  35. from sympy.core import Basic, Expr, sympify, Add, Mul, S
  36. from sympy.core.containers import Tuple, Dict
  37. from sympy.core.sorting import default_sort_key
  38. from sympy.core.symbol import Symbol, symbols
  39. from sympy.core.sympify import CantSympify, _sympify
  40. from sympy.core.operations import AssocOp
  41. from sympy.external.gmpy import SYMPY_INTS
  42. from sympy.matrices import eye
  43. from sympy.utilities.exceptions import (sympy_deprecation_warning,
  44. SymPyDeprecationWarning,
  45. ignore_warnings)
  46. from sympy.utilities.decorator import memoize_property, deprecated
  47. from sympy.utilities.iterables import sift
  48. def deprecate_data():
  49. sympy_deprecation_warning(
  50. """
  51. The data attribute of TensorIndexType is deprecated. Use The
  52. replace_with_arrays() method instead.
  53. """,
  54. deprecated_since_version="1.4",
  55. active_deprecations_target="deprecated-tensorindextype-attrs",
  56. stacklevel=4,
  57. )
  58. def deprecate_fun_eval():
  59. sympy_deprecation_warning(
  60. """
  61. The Tensor.fun_eval() method is deprecated. Use
  62. Tensor.substitute_indices() instead.
  63. """,
  64. deprecated_since_version="1.5",
  65. active_deprecations_target="deprecated-tensor-fun-eval",
  66. stacklevel=4,
  67. )
  68. def deprecate_call():
  69. sympy_deprecation_warning(
  70. """
  71. Calling a tensor like Tensor(*indices) is deprecated. Use
  72. Tensor.substitute_indices() instead.
  73. """,
  74. deprecated_since_version="1.5",
  75. active_deprecations_target="deprecated-tensor-fun-eval",
  76. stacklevel=4,
  77. )
  78. class _IndexStructure(CantSympify):
  79. """
  80. This class handles the indices (free and dummy ones). It contains the
  81. algorithms to manage the dummy indices replacements and contractions of
  82. free indices under multiplications of tensor expressions, as well as stuff
  83. related to canonicalization sorting, getting the permutation of the
  84. expression and so on. It also includes tools to get the ``TensorIndex``
  85. objects corresponding to the given index structure.
  86. """
  87. def __init__(self, free, dum, index_types, indices, canon_bp=False):
  88. self.free = free
  89. self.dum = dum
  90. self.index_types = index_types
  91. self.indices = indices
  92. self._ext_rank = len(self.free) + 2*len(self.dum)
  93. self.dum.sort(key=lambda x: x[0])
  94. @staticmethod
  95. def from_indices(*indices):
  96. """
  97. Create a new ``_IndexStructure`` object from a list of ``indices``.
  98. Explanation
  99. ===========
  100. ``indices`` ``TensorIndex`` objects, the indices. Contractions are
  101. detected upon construction.
  102. Examples
  103. ========
  104. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, _IndexStructure
  105. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  106. >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz)
  107. >>> _IndexStructure.from_indices(m0, m1, -m1, m3)
  108. _IndexStructure([(m0, 0), (m3, 3)], [(1, 2)], [Lorentz, Lorentz, Lorentz, Lorentz])
  109. """
  110. free, dum = _IndexStructure._free_dum_from_indices(*indices)
  111. index_types = [i.tensor_index_type for i in indices]
  112. indices = _IndexStructure._replace_dummy_names(indices, free, dum)
  113. return _IndexStructure(free, dum, index_types, indices)
  114. @staticmethod
  115. def from_components_free_dum(components, free, dum):
  116. index_types = []
  117. for component in components:
  118. index_types.extend(component.index_types)
  119. indices = _IndexStructure.generate_indices_from_free_dum_index_types(free, dum, index_types)
  120. return _IndexStructure(free, dum, index_types, indices)
  121. @staticmethod
  122. def _free_dum_from_indices(*indices):
  123. """
  124. Convert ``indices`` into ``free``, ``dum`` for single component tensor.
  125. Explanation
  126. ===========
  127. ``free`` list of tuples ``(index, pos, 0)``,
  128. where ``pos`` is the position of index in
  129. the list of indices formed by the component tensors
  130. ``dum`` list of tuples ``(pos_contr, pos_cov, 0, 0)``
  131. Examples
  132. ========
  133. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, \
  134. _IndexStructure
  135. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  136. >>> m0, m1, m2, m3 = tensor_indices('m0,m1,m2,m3', Lorentz)
  137. >>> _IndexStructure._free_dum_from_indices(m0, m1, -m1, m3)
  138. ([(m0, 0), (m3, 3)], [(1, 2)])
  139. """
  140. n = len(indices)
  141. if n == 1:
  142. return [(indices[0], 0)], []
  143. # find the positions of the free indices and of the dummy indices
  144. free = [True]*len(indices)
  145. index_dict = {}
  146. dum = []
  147. for i, index in enumerate(indices):
  148. name = index.name
  149. typ = index.tensor_index_type
  150. contr = index.is_up
  151. if (name, typ) in index_dict:
  152. # found a pair of dummy indices
  153. is_contr, pos = index_dict[(name, typ)]
  154. # check consistency and update free
  155. if is_contr:
  156. if contr:
  157. raise ValueError('two equal contravariant indices in slots %d and %d' %(pos, i))
  158. else:
  159. free[pos] = False
  160. free[i] = False
  161. else:
  162. if contr:
  163. free[pos] = False
  164. free[i] = False
  165. else:
  166. raise ValueError('two equal covariant indices in slots %d and %d' %(pos, i))
  167. if contr:
  168. dum.append((i, pos))
  169. else:
  170. dum.append((pos, i))
  171. else:
  172. index_dict[(name, typ)] = index.is_up, i
  173. free = [(index, i) for i, index in enumerate(indices) if free[i]]
  174. free.sort()
  175. return free, dum
  176. def get_indices(self):
  177. """
  178. Get a list of indices, creating new tensor indices to complete dummy indices.
  179. """
  180. return self.indices[:]
  181. @staticmethod
  182. def generate_indices_from_free_dum_index_types(free, dum, index_types):
  183. indices = [None]*(len(free)+2*len(dum))
  184. for idx, pos in free:
  185. indices[pos] = idx
  186. generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free)
  187. for pos1, pos2 in dum:
  188. typ1 = index_types[pos1]
  189. indname = generate_dummy_name(typ1)
  190. indices[pos1] = TensorIndex(indname, typ1, True)
  191. indices[pos2] = TensorIndex(indname, typ1, False)
  192. return _IndexStructure._replace_dummy_names(indices, free, dum)
  193. @staticmethod
  194. def _get_generator_for_dummy_indices(free):
  195. cdt = defaultdict(int)
  196. # if the free indices have names with dummy_name, start with an
  197. # index higher than those for the dummy indices
  198. # to avoid name collisions
  199. for indx, ipos in free:
  200. if indx.name.split('_')[0] == indx.tensor_index_type.dummy_name:
  201. cdt[indx.tensor_index_type] = max(cdt[indx.tensor_index_type], int(indx.name.split('_')[1]) + 1)
  202. def dummy_name_gen(tensor_index_type):
  203. nd = str(cdt[tensor_index_type])
  204. cdt[tensor_index_type] += 1
  205. return tensor_index_type.dummy_name + '_' + nd
  206. return dummy_name_gen
  207. @staticmethod
  208. def _replace_dummy_names(indices, free, dum):
  209. dum.sort(key=lambda x: x[0])
  210. new_indices = list(indices)
  211. assert len(indices) == len(free) + 2*len(dum)
  212. generate_dummy_name = _IndexStructure._get_generator_for_dummy_indices(free)
  213. for ipos1, ipos2 in dum:
  214. typ1 = new_indices[ipos1].tensor_index_type
  215. indname = generate_dummy_name(typ1)
  216. new_indices[ipos1] = TensorIndex(indname, typ1, True)
  217. new_indices[ipos2] = TensorIndex(indname, typ1, False)
  218. return new_indices
  219. def get_free_indices(self) -> list[TensorIndex]:
  220. """
  221. Get a list of free indices.
  222. """
  223. # get sorted indices according to their position:
  224. free = sorted(self.free, key=lambda x: x[1])
  225. return [i[0] for i in free]
  226. def __str__(self):
  227. return "_IndexStructure({}, {}, {})".format(self.free, self.dum, self.index_types)
  228. def __repr__(self):
  229. return self.__str__()
  230. def _get_sorted_free_indices_for_canon(self):
  231. sorted_free = self.free[:]
  232. sorted_free.sort(key=lambda x: x[0])
  233. return sorted_free
  234. def _get_sorted_dum_indices_for_canon(self):
  235. return sorted(self.dum, key=lambda x: x[0])
  236. def _get_lexicographically_sorted_index_types(self):
  237. permutation = self.indices_canon_args()[0]
  238. index_types = [None]*self._ext_rank
  239. for i, it in enumerate(self.index_types):
  240. index_types[permutation(i)] = it
  241. return index_types
  242. def _get_lexicographically_sorted_indices(self):
  243. permutation = self.indices_canon_args()[0]
  244. indices = [None]*self._ext_rank
  245. for i, it in enumerate(self.indices):
  246. indices[permutation(i)] = it
  247. return indices
  248. def perm2tensor(self, g, is_canon_bp=False):
  249. """
  250. Returns a ``_IndexStructure`` instance corresponding to the permutation ``g``.
  251. Explanation
  252. ===========
  253. ``g`` permutation corresponding to the tensor in the representation
  254. used in canonicalization
  255. ``is_canon_bp`` if True, then ``g`` is the permutation
  256. corresponding to the canonical form of the tensor
  257. """
  258. sorted_free = [i[0] for i in self._get_sorted_free_indices_for_canon()]
  259. lex_index_types = self._get_lexicographically_sorted_index_types()
  260. lex_indices = self._get_lexicographically_sorted_indices()
  261. nfree = len(sorted_free)
  262. rank = self._ext_rank
  263. dum = [[None]*2 for i in range((rank - nfree)//2)]
  264. free = []
  265. index_types = [None]*rank
  266. indices = [None]*rank
  267. for i in range(rank):
  268. gi = g[i]
  269. index_types[i] = lex_index_types[gi]
  270. indices[i] = lex_indices[gi]
  271. if gi < nfree:
  272. ind = sorted_free[gi]
  273. assert index_types[i] == sorted_free[gi].tensor_index_type
  274. free.append((ind, i))
  275. else:
  276. j = gi - nfree
  277. idum, cov = divmod(j, 2)
  278. if cov:
  279. dum[idum][1] = i
  280. else:
  281. dum[idum][0] = i
  282. dum = [tuple(x) for x in dum]
  283. return _IndexStructure(free, dum, index_types, indices)
  284. def indices_canon_args(self):
  285. """
  286. Returns ``(g, dummies, msym, v)``, the entries of ``canonicalize``
  287. See ``canonicalize`` in ``tensor_can.py`` in combinatorics module.
  288. """
  289. # to be called after sorted_components
  290. from sympy.combinatorics.permutations import _af_new
  291. n = self._ext_rank
  292. g = [None]*n + [n, n+1]
  293. # Converts the symmetry of the metric into msym from .canonicalize()
  294. # method in the combinatorics module
  295. def metric_symmetry_to_msym(metric):
  296. if metric is None:
  297. return None
  298. sym = metric.symmetry
  299. if sym == TensorSymmetry.fully_symmetric(2):
  300. return 0
  301. if sym == TensorSymmetry.fully_symmetric(-2):
  302. return 1
  303. return None
  304. # ordered indices: first the free indices, ordered by types
  305. # then the dummy indices, ordered by types and contravariant before
  306. # covariant
  307. # g[position in tensor] = position in ordered indices
  308. for i, (indx, ipos) in enumerate(self._get_sorted_free_indices_for_canon()):
  309. g[ipos] = i
  310. pos = len(self.free)
  311. j = len(self.free)
  312. dummies = []
  313. prev = None
  314. a = []
  315. msym = []
  316. for ipos1, ipos2 in self._get_sorted_dum_indices_for_canon():
  317. g[ipos1] = j
  318. g[ipos2] = j + 1
  319. j += 2
  320. typ = self.index_types[ipos1]
  321. if typ != prev:
  322. if a:
  323. dummies.append(a)
  324. a = [pos, pos + 1]
  325. prev = typ
  326. msym.append(metric_symmetry_to_msym(typ.metric))
  327. else:
  328. a.extend([pos, pos + 1])
  329. pos += 2
  330. if a:
  331. dummies.append(a)
  332. return _af_new(g), dummies, msym
  333. def components_canon_args(components):
  334. numtyp = []
  335. prev = None
  336. for t in components:
  337. if t == prev:
  338. numtyp[-1][1] += 1
  339. else:
  340. prev = t
  341. numtyp.append([prev, 1])
  342. v = []
  343. for h, n in numtyp:
  344. if h.comm in (0, 1):
  345. comm = h.comm
  346. else:
  347. comm = TensorManager.get_comm(h.comm, h.comm)
  348. v.append((h.symmetry.base, h.symmetry.generators, n, comm))
  349. return v
  350. class _TensorDataLazyEvaluator(CantSympify):
  351. """
  352. EXPERIMENTAL: do not rely on this class, it may change without deprecation
  353. warnings in future versions of SymPy.
  354. Explanation
  355. ===========
  356. This object contains the logic to associate components data to a tensor
  357. expression. Components data are set via the ``.data`` property of tensor
  358. expressions, is stored inside this class as a mapping between the tensor
  359. expression and the ``ndarray``.
  360. Computations are executed lazily: whereas the tensor expressions can have
  361. contractions, tensor products, and additions, components data are not
  362. computed until they are accessed by reading the ``.data`` property
  363. associated to the tensor expression.
  364. """
  365. _substitutions_dict: dict[Any, Any] = {}
  366. _substitutions_dict_tensmul: dict[Any, Any] = {}
  367. def __getitem__(self, key):
  368. dat = self._get(key)
  369. if dat is None:
  370. return None
  371. from .array import NDimArray
  372. if not isinstance(dat, NDimArray):
  373. return dat
  374. if dat.rank() == 0:
  375. return dat[()]
  376. elif dat.rank() == 1 and len(dat) == 1:
  377. return dat[0]
  378. return dat
  379. def _get(self, key):
  380. """
  381. Retrieve ``data`` associated with ``key``.
  382. Explanation
  383. ===========
  384. This algorithm looks into ``self._substitutions_dict`` for all
  385. ``TensorHead`` in the ``TensExpr`` (or just ``TensorHead`` if key is a
  386. TensorHead instance). It reconstructs the components data that the
  387. tensor expression should have by performing on components data the
  388. operations that correspond to the abstract tensor operations applied.
  389. Metric tensor is handled in a different manner: it is pre-computed in
  390. ``self._substitutions_dict_tensmul``.
  391. """
  392. if key in self._substitutions_dict:
  393. return self._substitutions_dict[key]
  394. if isinstance(key, TensorHead):
  395. return None
  396. if isinstance(key, Tensor):
  397. # special case to handle metrics. Metric tensors cannot be
  398. # constructed through contraction by the metric, their
  399. # components show if they are a matrix or its inverse.
  400. signature = tuple([i.is_up for i in key.get_indices()])
  401. srch = (key.component,) + signature
  402. if srch in self._substitutions_dict_tensmul:
  403. return self._substitutions_dict_tensmul[srch]
  404. array_list = [self.data_from_tensor(key)]
  405. return self.data_contract_dum(array_list, key.dum, key.ext_rank)
  406. if isinstance(key, TensMul):
  407. tensmul_args = key.args
  408. if len(tensmul_args) == 1 and len(tensmul_args[0].components) == 1:
  409. # special case to handle metrics. Metric tensors cannot be
  410. # constructed through contraction by the metric, their
  411. # components show if they are a matrix or its inverse.
  412. signature = tuple([i.is_up for i in tensmul_args[0].get_indices()])
  413. srch = (tensmul_args[0].components[0],) + signature
  414. if srch in self._substitutions_dict_tensmul:
  415. return self._substitutions_dict_tensmul[srch]
  416. #data_list = [self.data_from_tensor(i) for i in tensmul_args if isinstance(i, TensExpr)]
  417. data_list = [self.data_from_tensor(i) if isinstance(i, Tensor) else i.data for i in tensmul_args if isinstance(i, TensExpr)]
  418. coeff = prod([i for i in tensmul_args if not isinstance(i, TensExpr)])
  419. if all(i is None for i in data_list):
  420. return None
  421. if any(i is None for i in data_list):
  422. raise ValueError("Mixing tensors with associated components "\
  423. "data with tensors without components data")
  424. data_result = self.data_contract_dum(data_list, key.dum, key.ext_rank)
  425. return coeff*data_result
  426. if isinstance(key, TensAdd):
  427. data_list = []
  428. free_args_list = []
  429. for arg in key.args:
  430. if isinstance(arg, TensExpr):
  431. data_list.append(arg.data)
  432. free_args_list.append([x[0] for x in arg.free])
  433. else:
  434. data_list.append(arg)
  435. free_args_list.append([])
  436. if all(i is None for i in data_list):
  437. return None
  438. if any(i is None for i in data_list):
  439. raise ValueError("Mixing tensors with associated components "\
  440. "data with tensors without components data")
  441. sum_list = []
  442. from .array import permutedims
  443. for data, free_args in zip(data_list, free_args_list):
  444. if len(free_args) < 2:
  445. sum_list.append(data)
  446. else:
  447. free_args_pos = {y: x for x, y in enumerate(free_args)}
  448. axes = [free_args_pos[arg] for arg in key.free_args]
  449. sum_list.append(permutedims(data, axes))
  450. return reduce(lambda x, y: x+y, sum_list)
  451. return None
  452. @staticmethod
  453. def data_contract_dum(ndarray_list, dum, ext_rank):
  454. from .array import tensorproduct, tensorcontraction, MutableDenseNDimArray
  455. arrays = list(map(MutableDenseNDimArray, ndarray_list))
  456. prodarr = tensorproduct(*arrays)
  457. return tensorcontraction(prodarr, *dum)
  458. def data_tensorhead_from_tensmul(self, data, tensmul, tensorhead):
  459. """
  460. This method is used when assigning components data to a ``TensMul``
  461. object, it converts components data to a fully contravariant ndarray,
  462. which is then stored according to the ``TensorHead`` key.
  463. """
  464. if data is None:
  465. return None
  466. return self._correct_signature_from_indices(
  467. data,
  468. tensmul.get_indices(),
  469. tensmul.free,
  470. tensmul.dum,
  471. True)
  472. def data_from_tensor(self, tensor):
  473. """
  474. This method corrects the components data to the right signature
  475. (covariant/contravariant) using the metric associated with each
  476. ``TensorIndexType``.
  477. """
  478. tensorhead = tensor.component
  479. if tensorhead.data is None:
  480. return None
  481. return self._correct_signature_from_indices(
  482. tensorhead.data,
  483. tensor.get_indices(),
  484. tensor.free,
  485. tensor.dum)
  486. def _assign_data_to_tensor_expr(self, key, data):
  487. if isinstance(key, TensAdd):
  488. raise ValueError('cannot assign data to TensAdd')
  489. # here it is assumed that `key` is a `TensMul` instance.
  490. if len(key.components) != 1:
  491. raise ValueError('cannot assign data to TensMul with multiple components')
  492. tensorhead = key.components[0]
  493. newdata = self.data_tensorhead_from_tensmul(data, key, tensorhead)
  494. return tensorhead, newdata
  495. def _check_permutations_on_data(self, tens, data):
  496. from .array import permutedims
  497. from .array.arrayop import Flatten
  498. if isinstance(tens, TensorHead):
  499. rank = tens.rank
  500. generators = tens.symmetry.generators
  501. elif isinstance(tens, Tensor):
  502. rank = tens.rank
  503. generators = tens.components[0].symmetry.generators
  504. elif isinstance(tens, TensorIndexType):
  505. rank = tens.metric.rank
  506. generators = tens.metric.symmetry.generators
  507. # Every generator is a permutation, check that by permuting the array
  508. # by that permutation, the array will be the same, except for a
  509. # possible sign change if the permutation admits it.
  510. for gener in generators:
  511. sign_change = +1 if (gener(rank) == rank) else -1
  512. data_swapped = data
  513. last_data = data
  514. permute_axes = list(map(gener, range(rank)))
  515. # the order of a permutation is the number of times to get the
  516. # identity by applying that permutation.
  517. for i in range(gener.order()-1):
  518. data_swapped = permutedims(data_swapped, permute_axes)
  519. # if any value in the difference array is non-zero, raise an error:
  520. if any(Flatten(last_data - sign_change*data_swapped)):
  521. raise ValueError("Component data symmetry structure error")
  522. last_data = data_swapped
  523. def __setitem__(self, key, value):
  524. """
  525. Set the components data of a tensor object/expression.
  526. Explanation
  527. ===========
  528. Components data are transformed to the all-contravariant form and stored
  529. with the corresponding ``TensorHead`` object. If a ``TensorHead`` object
  530. cannot be uniquely identified, it will raise an error.
  531. """
  532. data = _TensorDataLazyEvaluator.parse_data(value)
  533. self._check_permutations_on_data(key, data)
  534. # TensorHead and TensorIndexType can be assigned data directly, while
  535. # TensMul must first convert data to a fully contravariant form, and
  536. # assign it to its corresponding TensorHead single component.
  537. if not isinstance(key, (TensorHead, TensorIndexType)):
  538. key, data = self._assign_data_to_tensor_expr(key, data)
  539. if isinstance(key, TensorHead):
  540. for dim, indextype in zip(data.shape, key.index_types):
  541. if indextype.data is None:
  542. raise ValueError("index type {} has no components data"\
  543. " associated (needed to raise/lower index)".format(indextype))
  544. if not indextype.dim.is_number:
  545. continue
  546. if dim != indextype.dim:
  547. raise ValueError("wrong dimension of ndarray")
  548. self._substitutions_dict[key] = data
  549. def __delitem__(self, key):
  550. del self._substitutions_dict[key]
  551. def __contains__(self, key):
  552. return key in self._substitutions_dict
  553. def add_metric_data(self, metric, data):
  554. """
  555. Assign data to the ``metric`` tensor. The metric tensor behaves in an
  556. anomalous way when raising and lowering indices.
  557. Explanation
  558. ===========
  559. A fully covariant metric is the inverse transpose of the fully
  560. contravariant metric (it is meant matrix inverse). If the metric is
  561. symmetric, the transpose is not necessary and mixed
  562. covariant/contravariant metrics are Kronecker deltas.
  563. """
  564. # hard assignment, data should not be added to `TensorHead` for metric:
  565. # the problem with `TensorHead` is that the metric is anomalous, i.e.
  566. # raising and lowering the index means considering the metric or its
  567. # inverse, this is not the case for other tensors.
  568. self._substitutions_dict_tensmul[metric, True, True] = data
  569. inverse_transpose = self.inverse_transpose_matrix(data)
  570. # in symmetric spaces, the transpose is the same as the original matrix,
  571. # the full covariant metric tensor is the inverse transpose, so this
  572. # code will be able to handle non-symmetric metrics.
  573. self._substitutions_dict_tensmul[metric, False, False] = inverse_transpose
  574. # now mixed cases, these are identical to the unit matrix if the metric
  575. # is symmetric.
  576. m = data.tomatrix()
  577. invt = inverse_transpose.tomatrix()
  578. self._substitutions_dict_tensmul[metric, True, False] = m * invt
  579. self._substitutions_dict_tensmul[metric, False, True] = invt * m
  580. @staticmethod
  581. def _flip_index_by_metric(data, metric, pos):
  582. from .array import tensorproduct, tensorcontraction
  583. mdim = metric.rank()
  584. ddim = data.rank()
  585. if pos == 0:
  586. data = tensorcontraction(
  587. tensorproduct(
  588. metric,
  589. data
  590. ),
  591. (1, mdim+pos)
  592. )
  593. else:
  594. data = tensorcontraction(
  595. tensorproduct(
  596. data,
  597. metric
  598. ),
  599. (pos, ddim)
  600. )
  601. return data
  602. @staticmethod
  603. def inverse_matrix(ndarray):
  604. m = ndarray.tomatrix().inv()
  605. return _TensorDataLazyEvaluator.parse_data(m)
  606. @staticmethod
  607. def inverse_transpose_matrix(ndarray):
  608. m = ndarray.tomatrix().inv().T
  609. return _TensorDataLazyEvaluator.parse_data(m)
  610. @staticmethod
  611. def _correct_signature_from_indices(data, indices, free, dum, inverse=False):
  612. """
  613. Utility function to correct the values inside the components data
  614. ndarray according to whether indices are covariant or contravariant.
  615. It uses the metric matrix to lower values of covariant indices.
  616. """
  617. # change the ndarray values according covariantness/contravariantness of the indices
  618. # use the metric
  619. for i, indx in enumerate(indices):
  620. if not indx.is_up and not inverse:
  621. data = _TensorDataLazyEvaluator._flip_index_by_metric(data, indx.tensor_index_type.data, i)
  622. elif not indx.is_up and inverse:
  623. data = _TensorDataLazyEvaluator._flip_index_by_metric(
  624. data,
  625. _TensorDataLazyEvaluator.inverse_matrix(indx.tensor_index_type.data),
  626. i
  627. )
  628. return data
  629. @staticmethod
  630. def _sort_data_axes(old, new):
  631. from .array import permutedims
  632. new_data = old.data.copy()
  633. old_free = [i[0] for i in old.free]
  634. new_free = [i[0] for i in new.free]
  635. for i in range(len(new_free)):
  636. for j in range(i, len(old_free)):
  637. if old_free[j] == new_free[i]:
  638. old_free[i], old_free[j] = old_free[j], old_free[i]
  639. new_data = permutedims(new_data, (i, j))
  640. break
  641. return new_data
  642. @staticmethod
  643. def add_rearrange_tensmul_parts(new_tensmul, old_tensmul):
  644. def sorted_compo():
  645. return _TensorDataLazyEvaluator._sort_data_axes(old_tensmul, new_tensmul)
  646. _TensorDataLazyEvaluator._substitutions_dict[new_tensmul] = sorted_compo()
  647. @staticmethod
  648. def parse_data(data):
  649. """
  650. Transform ``data`` to array. The parameter ``data`` may
  651. contain data in various formats, e.g. nested lists, SymPy ``Matrix``,
  652. and so on.
  653. Examples
  654. ========
  655. >>> from sympy.tensor.tensor import _TensorDataLazyEvaluator
  656. >>> _TensorDataLazyEvaluator.parse_data([1, 3, -6, 12])
  657. [1, 3, -6, 12]
  658. >>> _TensorDataLazyEvaluator.parse_data([[1, 2], [4, 7]])
  659. [[1, 2], [4, 7]]
  660. """
  661. from .array import MutableDenseNDimArray
  662. if not isinstance(data, MutableDenseNDimArray):
  663. if len(data) == 2 and hasattr(data[0], '__call__'):
  664. data = MutableDenseNDimArray(data[0], data[1])
  665. else:
  666. data = MutableDenseNDimArray(data)
  667. return data
  668. _tensor_data_substitution_dict = _TensorDataLazyEvaluator()
  669. class _TensorManager:
  670. """
  671. Class to manage tensor properties.
  672. Notes
  673. =====
  674. Tensors belong to tensor commutation groups; each group has a label
  675. ``comm``; there are predefined labels:
  676. ``0`` tensors commuting with any other tensor
  677. ``1`` tensors anticommuting among themselves
  678. ``2`` tensors not commuting, apart with those with ``comm=0``
  679. Other groups can be defined using ``set_comm``; tensors in those
  680. groups commute with those with ``comm=0``; by default they
  681. do not commute with any other group.
  682. """
  683. def __init__(self):
  684. self._comm_init()
  685. def _comm_init(self):
  686. self._comm = [{} for i in range(3)]
  687. for i in range(3):
  688. self._comm[0][i] = 0
  689. self._comm[i][0] = 0
  690. self._comm[1][1] = 1
  691. self._comm[2][1] = None
  692. self._comm[1][2] = None
  693. self._comm_symbols2i = {0:0, 1:1, 2:2}
  694. self._comm_i2symbol = {0:0, 1:1, 2:2}
  695. @property
  696. def comm(self):
  697. return self._comm
  698. def comm_symbols2i(self, i):
  699. """
  700. Get the commutation group number corresponding to ``i``.
  701. ``i`` can be a symbol or a number or a string.
  702. If ``i`` is not already defined its commutation group number
  703. is set.
  704. """
  705. if i not in self._comm_symbols2i:
  706. n = len(self._comm)
  707. self._comm.append({})
  708. self._comm[n][0] = 0
  709. self._comm[0][n] = 0
  710. self._comm_symbols2i[i] = n
  711. self._comm_i2symbol[n] = i
  712. return n
  713. return self._comm_symbols2i[i]
  714. def comm_i2symbol(self, i):
  715. """
  716. Returns the symbol corresponding to the commutation group number.
  717. """
  718. return self._comm_i2symbol[i]
  719. def set_comm(self, i, j, c):
  720. """
  721. Set the commutation parameter ``c`` for commutation groups ``i, j``.
  722. Parameters
  723. ==========
  724. i, j : symbols representing commutation groups
  725. c : group commutation number
  726. Notes
  727. =====
  728. ``i, j`` can be symbols, strings or numbers,
  729. apart from ``0, 1`` and ``2`` which are reserved respectively
  730. for commuting, anticommuting tensors and tensors not commuting
  731. with any other group apart with the commuting tensors.
  732. For the remaining cases, use this method to set the commutation rules;
  733. by default ``c=None``.
  734. The group commutation number ``c`` is assigned in correspondence
  735. to the group commutation symbols; it can be
  736. 0 commuting
  737. 1 anticommuting
  738. None no commutation property
  739. Examples
  740. ========
  741. ``G`` and ``GH`` do not commute with themselves and commute with
  742. each other; A is commuting.
  743. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorManager, TensorSymmetry
  744. >>> Lorentz = TensorIndexType('Lorentz')
  745. >>> i0,i1,i2,i3,i4 = tensor_indices('i0:5', Lorentz)
  746. >>> A = TensorHead('A', [Lorentz])
  747. >>> G = TensorHead('G', [Lorentz], TensorSymmetry.no_symmetry(1), 'Gcomm')
  748. >>> GH = TensorHead('GH', [Lorentz], TensorSymmetry.no_symmetry(1), 'GHcomm')
  749. >>> TensorManager.set_comm('Gcomm', 'GHcomm', 0)
  750. >>> (GH(i1)*G(i0)).canon_bp()
  751. G(i0)*GH(i1)
  752. >>> (G(i1)*G(i0)).canon_bp()
  753. G(i1)*G(i0)
  754. >>> (G(i1)*A(i0)).canon_bp()
  755. A(i0)*G(i1)
  756. """
  757. if c not in (0, 1, None):
  758. raise ValueError('`c` can assume only the values 0, 1 or None')
  759. if i not in self._comm_symbols2i:
  760. n = len(self._comm)
  761. self._comm.append({})
  762. self._comm[n][0] = 0
  763. self._comm[0][n] = 0
  764. self._comm_symbols2i[i] = n
  765. self._comm_i2symbol[n] = i
  766. if j not in self._comm_symbols2i:
  767. n = len(self._comm)
  768. self._comm.append({})
  769. self._comm[0][n] = 0
  770. self._comm[n][0] = 0
  771. self._comm_symbols2i[j] = n
  772. self._comm_i2symbol[n] = j
  773. ni = self._comm_symbols2i[i]
  774. nj = self._comm_symbols2i[j]
  775. self._comm[ni][nj] = c
  776. self._comm[nj][ni] = c
  777. def set_comms(self, *args):
  778. """
  779. Set the commutation group numbers ``c`` for symbols ``i, j``.
  780. Parameters
  781. ==========
  782. args : sequence of ``(i, j, c)``
  783. """
  784. for i, j, c in args:
  785. self.set_comm(i, j, c)
  786. def get_comm(self, i, j):
  787. """
  788. Return the commutation parameter for commutation group numbers ``i, j``
  789. see ``_TensorManager.set_comm``
  790. """
  791. return self._comm[i].get(j, 0 if i == 0 or j == 0 else None)
  792. def clear(self):
  793. """
  794. Clear the TensorManager.
  795. """
  796. self._comm_init()
  797. TensorManager = _TensorManager()
  798. class TensorIndexType(Basic):
  799. """
  800. A TensorIndexType is characterized by its name and its metric.
  801. Parameters
  802. ==========
  803. name : name of the tensor type
  804. dummy_name : name of the head of dummy indices
  805. dim : dimension, it can be a symbol or an integer or ``None``
  806. eps_dim : dimension of the epsilon tensor
  807. metric_symmetry : integer that denotes metric symmetry or ``None`` for no metric
  808. metric_name : string with the name of the metric tensor
  809. Attributes
  810. ==========
  811. ``metric`` : the metric tensor
  812. ``delta`` : ``Kronecker delta``
  813. ``epsilon`` : the ``Levi-Civita epsilon`` tensor
  814. ``data`` : (deprecated) a property to add ``ndarray`` values, to work in a specified basis.
  815. Notes
  816. =====
  817. The possible values of the ``metric_symmetry`` parameter are:
  818. ``1`` : metric tensor is fully symmetric
  819. ``0`` : metric tensor possesses no index symmetry
  820. ``-1`` : metric tensor is fully antisymmetric
  821. ``None``: there is no metric tensor (metric equals to ``None``)
  822. The metric is assumed to be symmetric by default. It can also be set
  823. to a custom tensor by the ``.set_metric()`` method.
  824. If there is a metric the metric is used to raise and lower indices.
  825. In the case of non-symmetric metric, the following raising and
  826. lowering conventions will be adopted:
  827. ``psi(a) = g(a, b)*psi(-b); chi(-a) = chi(b)*g(-b, -a)``
  828. From these it is easy to find:
  829. ``g(-a, b) = delta(-a, b)``
  830. where ``delta(-a, b) = delta(b, -a)`` is the ``Kronecker delta``
  831. (see ``TensorIndex`` for the conventions on indices).
  832. For antisymmetric metrics there is also the following equality:
  833. ``g(a, -b) = -delta(a, -b)``
  834. If there is no metric it is not possible to raise or lower indices;
  835. e.g. the index of the defining representation of ``SU(N)``
  836. is 'covariant' and the conjugate representation is
  837. 'contravariant'; for ``N > 2`` they are linearly independent.
  838. ``eps_dim`` is by default equal to ``dim``, if the latter is an integer;
  839. else it can be assigned (for use in naive dimensional regularization);
  840. if ``eps_dim`` is not an integer ``epsilon`` is ``None``.
  841. Examples
  842. ========
  843. >>> from sympy.tensor.tensor import TensorIndexType
  844. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  845. >>> Lorentz.metric
  846. metric(Lorentz,Lorentz)
  847. """
  848. def __new__(cls, name, dummy_name=None, dim=None, eps_dim=None,
  849. metric_symmetry=1, metric_name='metric', **kwargs):
  850. if 'dummy_fmt' in kwargs:
  851. dummy_fmt = kwargs['dummy_fmt']
  852. sympy_deprecation_warning(
  853. f"""
  854. The dummy_fmt keyword to TensorIndexType is deprecated. Use
  855. dummy_name={dummy_fmt} instead.
  856. """,
  857. deprecated_since_version="1.5",
  858. active_deprecations_target="deprecated-tensorindextype-dummy-fmt",
  859. )
  860. dummy_name = dummy_fmt
  861. if isinstance(name, str):
  862. name = Symbol(name)
  863. if dummy_name is None:
  864. dummy_name = str(name)[0]
  865. if isinstance(dummy_name, str):
  866. dummy_name = Symbol(dummy_name)
  867. if dim is None:
  868. dim = Symbol("dim_" + dummy_name.name)
  869. else:
  870. dim = sympify(dim)
  871. if eps_dim is None:
  872. eps_dim = dim
  873. else:
  874. eps_dim = sympify(eps_dim)
  875. metric_symmetry = sympify(metric_symmetry)
  876. if isinstance(metric_name, str):
  877. metric_name = Symbol(metric_name)
  878. if 'metric' in kwargs:
  879. SymPyDeprecationWarning(
  880. """
  881. The 'metric' keyword argument to TensorIndexType is
  882. deprecated. Use the 'metric_symmetry' keyword argument or the
  883. TensorIndexType.set_metric() method instead.
  884. """,
  885. deprecated_since_version="1.5",
  886. active_deprecations_target="deprecated-tensorindextype-metric",
  887. )
  888. metric = kwargs.get('metric')
  889. if metric is not None:
  890. if metric in (True, False, 0, 1):
  891. metric_name = 'metric'
  892. #metric_antisym = metric
  893. else:
  894. metric_name = metric.name
  895. #metric_antisym = metric.antisym
  896. if metric:
  897. metric_symmetry = -1
  898. else:
  899. metric_symmetry = 1
  900. obj = Basic.__new__(cls, name, dummy_name, dim, eps_dim,
  901. metric_symmetry, metric_name)
  902. obj._autogenerated = []
  903. return obj
  904. @property
  905. def name(self):
  906. return self.args[0].name
  907. @property
  908. def dummy_name(self):
  909. return self.args[1].name
  910. @property
  911. def dim(self):
  912. return self.args[2]
  913. @property
  914. def eps_dim(self):
  915. return self.args[3]
  916. @memoize_property
  917. def metric(self):
  918. metric_symmetry = self.args[4]
  919. metric_name = self.args[5]
  920. if metric_symmetry is None:
  921. return None
  922. if metric_symmetry == 0:
  923. symmetry = TensorSymmetry.no_symmetry(2)
  924. elif metric_symmetry == 1:
  925. symmetry = TensorSymmetry.fully_symmetric(2)
  926. elif metric_symmetry == -1:
  927. symmetry = TensorSymmetry.fully_symmetric(-2)
  928. return TensorHead(metric_name, [self]*2, symmetry)
  929. @memoize_property
  930. def delta(self):
  931. return TensorHead('KD', [self]*2, TensorSymmetry.fully_symmetric(2))
  932. @memoize_property
  933. def epsilon(self):
  934. if not isinstance(self.eps_dim, (SYMPY_INTS, Integer)):
  935. return None
  936. symmetry = TensorSymmetry.fully_symmetric(-self.eps_dim)
  937. return TensorHead('Eps', [self]*self.eps_dim, symmetry)
  938. def set_metric(self, tensor):
  939. self._metric = tensor
  940. def __lt__(self, other):
  941. return self.name < other.name
  942. def __str__(self):
  943. return self.name
  944. __repr__ = __str__
  945. # Everything below this line is deprecated
  946. @property
  947. def data(self):
  948. deprecate_data()
  949. with ignore_warnings(SymPyDeprecationWarning):
  950. return _tensor_data_substitution_dict[self]
  951. @data.setter
  952. def data(self, data):
  953. deprecate_data()
  954. # This assignment is a bit controversial, should metric components be assigned
  955. # to the metric only or also to the TensorIndexType object? The advantage here
  956. # is the ability to assign a 1D array and transform it to a 2D diagonal array.
  957. from .array import MutableDenseNDimArray
  958. data = _TensorDataLazyEvaluator.parse_data(data)
  959. if data.rank() > 2:
  960. raise ValueError("data have to be of rank 1 (diagonal metric) or 2.")
  961. if data.rank() == 1:
  962. if self.dim.is_number:
  963. nda_dim = data.shape[0]
  964. if nda_dim != self.dim:
  965. raise ValueError("Dimension mismatch")
  966. dim = data.shape[0]
  967. newndarray = MutableDenseNDimArray.zeros(dim, dim)
  968. for i, val in enumerate(data):
  969. newndarray[i, i] = val
  970. data = newndarray
  971. dim1, dim2 = data.shape
  972. if dim1 != dim2:
  973. raise ValueError("Non-square matrix tensor.")
  974. if self.dim.is_number:
  975. if self.dim != dim1:
  976. raise ValueError("Dimension mismatch")
  977. _tensor_data_substitution_dict[self] = data
  978. _tensor_data_substitution_dict.add_metric_data(self.metric, data)
  979. with ignore_warnings(SymPyDeprecationWarning):
  980. delta = self.get_kronecker_delta()
  981. i1 = TensorIndex('i1', self)
  982. i2 = TensorIndex('i2', self)
  983. with ignore_warnings(SymPyDeprecationWarning):
  984. delta(i1, -i2).data = _TensorDataLazyEvaluator.parse_data(eye(dim1))
  985. @data.deleter
  986. def data(self):
  987. deprecate_data()
  988. with ignore_warnings(SymPyDeprecationWarning):
  989. if self in _tensor_data_substitution_dict:
  990. del _tensor_data_substitution_dict[self]
  991. if self.metric in _tensor_data_substitution_dict:
  992. del _tensor_data_substitution_dict[self.metric]
  993. @deprecated(
  994. """
  995. The TensorIndexType.get_kronecker_delta() method is deprecated. Use
  996. the TensorIndexType.delta attribute instead.
  997. """,
  998. deprecated_since_version="1.5",
  999. active_deprecations_target="deprecated-tensorindextype-methods",
  1000. )
  1001. def get_kronecker_delta(self):
  1002. sym2 = TensorSymmetry(get_symmetric_group_sgs(2))
  1003. delta = TensorHead('KD', [self]*2, sym2)
  1004. return delta
  1005. @deprecated(
  1006. """
  1007. The TensorIndexType.get_epsilon() method is deprecated. Use
  1008. the TensorIndexType.epsilon attribute instead.
  1009. """,
  1010. deprecated_since_version="1.5",
  1011. active_deprecations_target="deprecated-tensorindextype-methods",
  1012. )
  1013. def get_epsilon(self):
  1014. if not isinstance(self._eps_dim, (SYMPY_INTS, Integer)):
  1015. return None
  1016. sym = TensorSymmetry(get_symmetric_group_sgs(self._eps_dim, 1))
  1017. epsilon = TensorHead('Eps', [self]*self._eps_dim, sym)
  1018. return epsilon
  1019. def _components_data_full_destroy(self):
  1020. """
  1021. EXPERIMENTAL: do not rely on this API method.
  1022. This destroys components data associated to the ``TensorIndexType``, if
  1023. any, specifically:
  1024. * metric tensor data
  1025. * Kronecker tensor data
  1026. """
  1027. if self in _tensor_data_substitution_dict:
  1028. del _tensor_data_substitution_dict[self]
  1029. def delete_tensmul_data(key):
  1030. if key in _tensor_data_substitution_dict._substitutions_dict_tensmul:
  1031. del _tensor_data_substitution_dict._substitutions_dict_tensmul[key]
  1032. # delete metric data:
  1033. delete_tensmul_data((self.metric, True, True))
  1034. delete_tensmul_data((self.metric, True, False))
  1035. delete_tensmul_data((self.metric, False, True))
  1036. delete_tensmul_data((self.metric, False, False))
  1037. # delete delta tensor data:
  1038. delta = self.get_kronecker_delta()
  1039. if delta in _tensor_data_substitution_dict:
  1040. del _tensor_data_substitution_dict[delta]
  1041. class TensorIndex(Basic):
  1042. """
  1043. Represents a tensor index
  1044. Parameters
  1045. ==========
  1046. name : name of the index, or ``True`` if you want it to be automatically assigned
  1047. tensor_index_type : ``TensorIndexType`` of the index
  1048. is_up : flag for contravariant index (is_up=True by default)
  1049. Attributes
  1050. ==========
  1051. ``name``
  1052. ``tensor_index_type``
  1053. ``is_up``
  1054. Notes
  1055. =====
  1056. Tensor indices are contracted with the Einstein summation convention.
  1057. An index can be in contravariant or in covariant form; in the latter
  1058. case it is represented prepending a ``-`` to the index name. Adding
  1059. ``-`` to a covariant (is_up=False) index makes it contravariant.
  1060. Dummy indices have a name with head given by
  1061. ``tensor_inde_type.dummy_name`` with underscore and a number.
  1062. Similar to ``symbols`` multiple contravariant indices can be created
  1063. at once using ``tensor_indices(s, typ)``, where ``s`` is a string
  1064. of names.
  1065. Examples
  1066. ========
  1067. >>> from sympy.tensor.tensor import TensorIndexType, TensorIndex, TensorHead, tensor_indices
  1068. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1069. >>> mu = TensorIndex('mu', Lorentz, is_up=False)
  1070. >>> nu, rho = tensor_indices('nu, rho', Lorentz)
  1071. >>> A = TensorHead('A', [Lorentz, Lorentz])
  1072. >>> A(mu, nu)
  1073. A(-mu, nu)
  1074. >>> A(-mu, -rho)
  1075. A(mu, -rho)
  1076. >>> A(mu, -mu)
  1077. A(-L_0, L_0)
  1078. """
  1079. def __new__(cls, name, tensor_index_type, is_up=True):
  1080. if isinstance(name, str):
  1081. name_symbol = Symbol(name)
  1082. elif isinstance(name, Symbol):
  1083. name_symbol = name
  1084. elif name is True:
  1085. name = "_i{}".format(len(tensor_index_type._autogenerated))
  1086. name_symbol = Symbol(name)
  1087. tensor_index_type._autogenerated.append(name_symbol)
  1088. else:
  1089. raise ValueError("invalid name")
  1090. is_up = sympify(is_up)
  1091. return Basic.__new__(cls, name_symbol, tensor_index_type, is_up)
  1092. @property
  1093. def name(self):
  1094. return self.args[0].name
  1095. @property
  1096. def tensor_index_type(self):
  1097. return self.args[1]
  1098. @property
  1099. def is_up(self):
  1100. return self.args[2]
  1101. def _print(self):
  1102. s = self.name
  1103. if not self.is_up:
  1104. s = '-%s' % s
  1105. return s
  1106. def __lt__(self, other):
  1107. return ((self.tensor_index_type, self.name) <
  1108. (other.tensor_index_type, other.name))
  1109. def __neg__(self):
  1110. t1 = TensorIndex(self.name, self.tensor_index_type,
  1111. (not self.is_up))
  1112. return t1
  1113. def tensor_indices(s, typ):
  1114. """
  1115. Returns list of tensor indices given their names and their types.
  1116. Parameters
  1117. ==========
  1118. s : string of comma separated names of indices
  1119. typ : ``TensorIndexType`` of the indices
  1120. Examples
  1121. ========
  1122. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices
  1123. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1124. >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz)
  1125. """
  1126. if isinstance(s, str):
  1127. a = [x.name for x in symbols(s, seq=True)]
  1128. else:
  1129. raise ValueError('expecting a string')
  1130. tilist = [TensorIndex(i, typ) for i in a]
  1131. if len(tilist) == 1:
  1132. return tilist[0]
  1133. return tilist
  1134. class TensorSymmetry(Basic):
  1135. """
  1136. Monoterm symmetry of a tensor (i.e. any symmetric or anti-symmetric
  1137. index permutation). For the relevant terminology see ``tensor_can.py``
  1138. section of the combinatorics module.
  1139. Parameters
  1140. ==========
  1141. bsgs : tuple ``(base, sgs)`` BSGS of the symmetry of the tensor
  1142. Attributes
  1143. ==========
  1144. ``base`` : base of the BSGS
  1145. ``generators`` : generators of the BSGS
  1146. ``rank`` : rank of the tensor
  1147. Notes
  1148. =====
  1149. A tensor can have an arbitrary monoterm symmetry provided by its BSGS.
  1150. Multiterm symmetries, like the cyclic symmetry of the Riemann tensor
  1151. (i.e., Bianchi identity), are not covered. See combinatorics module for
  1152. information on how to generate BSGS for a general index permutation group.
  1153. Simple symmetries can be generated using built-in methods.
  1154. See Also
  1155. ========
  1156. sympy.combinatorics.tensor_can.get_symmetric_group_sgs
  1157. Examples
  1158. ========
  1159. Define a symmetric tensor of rank 2
  1160. >>> from sympy.tensor.tensor import TensorIndexType, TensorSymmetry, get_symmetric_group_sgs, TensorHead
  1161. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1162. >>> sym = TensorSymmetry(get_symmetric_group_sgs(2))
  1163. >>> T = TensorHead('T', [Lorentz]*2, sym)
  1164. Note, that the same can also be done using built-in TensorSymmetry methods
  1165. >>> sym2 = TensorSymmetry.fully_symmetric(2)
  1166. >>> sym == sym2
  1167. True
  1168. """
  1169. def __new__(cls, *args, **kw_args):
  1170. if len(args) == 1:
  1171. base, generators = args[0]
  1172. elif len(args) == 2:
  1173. base, generators = args
  1174. else:
  1175. raise TypeError("bsgs required, either two separate parameters or one tuple")
  1176. if not isinstance(base, Tuple):
  1177. base = Tuple(*base)
  1178. if not isinstance(generators, Tuple):
  1179. generators = Tuple(*generators)
  1180. return Basic.__new__(cls, base, generators, **kw_args)
  1181. @property
  1182. def base(self):
  1183. return self.args[0]
  1184. @property
  1185. def generators(self):
  1186. return self.args[1]
  1187. @property
  1188. def rank(self):
  1189. return self.generators[0].size - 2
  1190. @classmethod
  1191. def fully_symmetric(cls, rank):
  1192. """
  1193. Returns a fully symmetric (antisymmetric if ``rank``<0)
  1194. TensorSymmetry object for ``abs(rank)`` indices.
  1195. """
  1196. if rank > 0:
  1197. bsgs = get_symmetric_group_sgs(rank, False)
  1198. elif rank < 0:
  1199. bsgs = get_symmetric_group_sgs(-rank, True)
  1200. elif rank == 0:
  1201. bsgs = ([], [Permutation(1)])
  1202. return TensorSymmetry(bsgs)
  1203. @classmethod
  1204. def direct_product(cls, *args):
  1205. """
  1206. Returns a TensorSymmetry object that is being a direct product of
  1207. fully (anti-)symmetric index permutation groups.
  1208. Notes
  1209. =====
  1210. Some examples for different values of ``(*args)``:
  1211. ``(1)`` vector, equivalent to ``TensorSymmetry.fully_symmetric(1)``
  1212. ``(2)`` tensor with 2 symmetric indices, equivalent to ``.fully_symmetric(2)``
  1213. ``(-2)`` tensor with 2 antisymmetric indices, equivalent to ``.fully_symmetric(-2)``
  1214. ``(2, -2)`` tensor with the first 2 indices commuting and the last 2 anticommuting
  1215. ``(1, 1, 1)`` tensor with 3 indices without any symmetry
  1216. """
  1217. base, sgs = [], [Permutation(1)]
  1218. for arg in args:
  1219. if arg > 0:
  1220. bsgs2 = get_symmetric_group_sgs(arg, False)
  1221. elif arg < 0:
  1222. bsgs2 = get_symmetric_group_sgs(-arg, True)
  1223. else:
  1224. continue
  1225. base, sgs = bsgs_direct_product(base, sgs, *bsgs2)
  1226. return TensorSymmetry(base, sgs)
  1227. @classmethod
  1228. def riemann(cls):
  1229. """
  1230. Returns a monotorem symmetry of the Riemann tensor
  1231. """
  1232. return TensorSymmetry(riemann_bsgs)
  1233. @classmethod
  1234. def no_symmetry(cls, rank):
  1235. """
  1236. TensorSymmetry object for ``rank`` indices with no symmetry
  1237. """
  1238. return TensorSymmetry([], [Permutation(rank+1)])
  1239. @deprecated(
  1240. """
  1241. The tensorsymmetry() function is deprecated. Use the TensorSymmetry
  1242. constructor instead.
  1243. """,
  1244. deprecated_since_version="1.5",
  1245. active_deprecations_target="deprecated-tensorsymmetry",
  1246. )
  1247. def tensorsymmetry(*args):
  1248. """
  1249. Returns a ``TensorSymmetry`` object. This method is deprecated, use
  1250. ``TensorSymmetry.direct_product()`` or ``.riemann()`` instead.
  1251. Explanation
  1252. ===========
  1253. One can represent a tensor with any monoterm slot symmetry group
  1254. using a BSGS.
  1255. ``args`` can be a BSGS
  1256. ``args[0]`` base
  1257. ``args[1]`` sgs
  1258. Usually tensors are in (direct products of) representations
  1259. of the symmetric group;
  1260. ``args`` can be a list of lists representing the shapes of Young tableaux
  1261. Notes
  1262. =====
  1263. For instance:
  1264. ``[[1]]`` vector
  1265. ``[[1]*n]`` symmetric tensor of rank ``n``
  1266. ``[[n]]`` antisymmetric tensor of rank ``n``
  1267. ``[[2, 2]]`` monoterm slot symmetry of the Riemann tensor
  1268. ``[[1],[1]]`` vector*vector
  1269. ``[[2],[1],[1]`` (antisymmetric tensor)*vector*vector
  1270. Notice that with the shape ``[2, 2]`` we associate only the monoterm
  1271. symmetries of the Riemann tensor; this is an abuse of notation,
  1272. since the shape ``[2, 2]`` corresponds usually to the irreducible
  1273. representation characterized by the monoterm symmetries and by the
  1274. cyclic symmetry.
  1275. """
  1276. from sympy.combinatorics import Permutation
  1277. def tableau2bsgs(a):
  1278. if len(a) == 1:
  1279. # antisymmetric vector
  1280. n = a[0]
  1281. bsgs = get_symmetric_group_sgs(n, 1)
  1282. else:
  1283. if all(x == 1 for x in a):
  1284. # symmetric vector
  1285. n = len(a)
  1286. bsgs = get_symmetric_group_sgs(n)
  1287. elif a == [2, 2]:
  1288. bsgs = riemann_bsgs
  1289. else:
  1290. raise NotImplementedError
  1291. return bsgs
  1292. if not args:
  1293. return TensorSymmetry(Tuple(), Tuple(Permutation(1)))
  1294. if len(args) == 2 and isinstance(args[1][0], Permutation):
  1295. return TensorSymmetry(args)
  1296. base, sgs = tableau2bsgs(args[0])
  1297. for a in args[1:]:
  1298. basex, sgsx = tableau2bsgs(a)
  1299. base, sgs = bsgs_direct_product(base, sgs, basex, sgsx)
  1300. return TensorSymmetry(Tuple(base, sgs))
  1301. @deprecated(
  1302. "TensorType is deprecated. Use tensor_heads() instead.",
  1303. deprecated_since_version="1.5",
  1304. active_deprecations_target="deprecated-tensortype",
  1305. )
  1306. class TensorType(Basic):
  1307. """
  1308. Class of tensor types. Deprecated, use tensor_heads() instead.
  1309. Parameters
  1310. ==========
  1311. index_types : list of ``TensorIndexType`` of the tensor indices
  1312. symmetry : ``TensorSymmetry`` of the tensor
  1313. Attributes
  1314. ==========
  1315. ``index_types``
  1316. ``symmetry``
  1317. ``types`` : list of ``TensorIndexType`` without repetitions
  1318. """
  1319. is_commutative = False
  1320. def __new__(cls, index_types, symmetry, **kw_args):
  1321. assert symmetry.rank == len(index_types)
  1322. obj = Basic.__new__(cls, Tuple(*index_types), symmetry, **kw_args)
  1323. return obj
  1324. @property
  1325. def index_types(self):
  1326. return self.args[0]
  1327. @property
  1328. def symmetry(self):
  1329. return self.args[1]
  1330. @property
  1331. def types(self):
  1332. return sorted(set(self.index_types), key=lambda x: x.name)
  1333. def __str__(self):
  1334. return 'TensorType(%s)' % ([str(x) for x in self.index_types])
  1335. def __call__(self, s, comm=0):
  1336. """
  1337. Return a TensorHead object or a list of TensorHead objects.
  1338. Parameters
  1339. ==========
  1340. s : name or string of names.
  1341. comm : Commutation group.
  1342. see ``_TensorManager.set_comm``
  1343. """
  1344. if isinstance(s, str):
  1345. names = [x.name for x in symbols(s, seq=True)]
  1346. else:
  1347. raise ValueError('expecting a string')
  1348. if len(names) == 1:
  1349. return TensorHead(names[0], self.index_types, self.symmetry, comm)
  1350. else:
  1351. return [TensorHead(name, self.index_types, self.symmetry, comm) for name in names]
  1352. @deprecated(
  1353. """
  1354. The tensorhead() function is deprecated. Use tensor_heads() instead.
  1355. """,
  1356. deprecated_since_version="1.5",
  1357. active_deprecations_target="deprecated-tensorhead",
  1358. )
  1359. def tensorhead(name, typ, sym=None, comm=0):
  1360. """
  1361. Function generating tensorhead(s). This method is deprecated,
  1362. use TensorHead constructor or tensor_heads() instead.
  1363. Parameters
  1364. ==========
  1365. name : name or sequence of names (as in ``symbols``)
  1366. typ : index types
  1367. sym : same as ``*args`` in ``tensorsymmetry``
  1368. comm : commutation group number
  1369. see ``_TensorManager.set_comm``
  1370. """
  1371. if sym is None:
  1372. sym = [[1] for i in range(len(typ))]
  1373. with ignore_warnings(SymPyDeprecationWarning):
  1374. sym = tensorsymmetry(*sym)
  1375. return TensorHead(name, typ, sym, comm)
  1376. class TensorHead(Basic):
  1377. """
  1378. Tensor head of the tensor.
  1379. Parameters
  1380. ==========
  1381. name : name of the tensor
  1382. index_types : list of TensorIndexType
  1383. symmetry : TensorSymmetry of the tensor
  1384. comm : commutation group number
  1385. Attributes
  1386. ==========
  1387. ``name``
  1388. ``index_types``
  1389. ``rank`` : total number of indices
  1390. ``symmetry``
  1391. ``comm`` : commutation group
  1392. Notes
  1393. =====
  1394. Similar to ``symbols`` multiple TensorHeads can be created using
  1395. ``tensorhead(s, typ, sym=None, comm=0)`` function, where ``s``
  1396. is the string of names and ``sym`` is the monoterm tensor symmetry
  1397. (see ``tensorsymmetry``).
  1398. A ``TensorHead`` belongs to a commutation group, defined by a
  1399. symbol on number ``comm`` (see ``_TensorManager.set_comm``);
  1400. tensors in a commutation group have the same commutation properties;
  1401. by default ``comm`` is ``0``, the group of the commuting tensors.
  1402. Examples
  1403. ========
  1404. Define a fully antisymmetric tensor of rank 2:
  1405. >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry
  1406. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1407. >>> asym2 = TensorSymmetry.fully_symmetric(-2)
  1408. >>> A = TensorHead('A', [Lorentz, Lorentz], asym2)
  1409. Examples with ndarray values, the components data assigned to the
  1410. ``TensorHead`` object are assumed to be in a fully-contravariant
  1411. representation. In case it is necessary to assign components data which
  1412. represents the values of a non-fully covariant tensor, see the other
  1413. examples.
  1414. >>> from sympy.tensor.tensor import tensor_indices
  1415. >>> from sympy import diag
  1416. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1417. >>> i0, i1 = tensor_indices('i0:2', Lorentz)
  1418. Specify a replacement dictionary to keep track of the arrays to use for
  1419. replacements in the tensorial expression. The ``TensorIndexType`` is
  1420. associated to the metric used for contractions (in fully covariant form):
  1421. >>> repl = {Lorentz: diag(1, -1, -1, -1)}
  1422. Let's see some examples of working with components with the electromagnetic
  1423. tensor:
  1424. >>> from sympy import symbols
  1425. >>> Ex, Ey, Ez, Bx, By, Bz = symbols('E_x E_y E_z B_x B_y B_z')
  1426. >>> c = symbols('c', positive=True)
  1427. Let's define `F`, an antisymmetric tensor:
  1428. >>> F = TensorHead('F', [Lorentz, Lorentz], asym2)
  1429. Let's update the dictionary to contain the matrix to use in the
  1430. replacements:
  1431. >>> repl.update({F(-i0, -i1): [
  1432. ... [0, Ex/c, Ey/c, Ez/c],
  1433. ... [-Ex/c, 0, -Bz, By],
  1434. ... [-Ey/c, Bz, 0, -Bx],
  1435. ... [-Ez/c, -By, Bx, 0]]})
  1436. Now it is possible to retrieve the contravariant form of the Electromagnetic
  1437. tensor:
  1438. >>> F(i0, i1).replace_with_arrays(repl, [i0, i1])
  1439. [[0, -E_x/c, -E_y/c, -E_z/c], [E_x/c, 0, -B_z, B_y], [E_y/c, B_z, 0, -B_x], [E_z/c, -B_y, B_x, 0]]
  1440. and the mixed contravariant-covariant form:
  1441. >>> F(i0, -i1).replace_with_arrays(repl, [i0, -i1])
  1442. [[0, E_x/c, E_y/c, E_z/c], [E_x/c, 0, B_z, -B_y], [E_y/c, -B_z, 0, B_x], [E_z/c, B_y, -B_x, 0]]
  1443. Energy-momentum of a particle may be represented as:
  1444. >>> from sympy import symbols
  1445. >>> P = TensorHead('P', [Lorentz], TensorSymmetry.no_symmetry(1))
  1446. >>> E, px, py, pz = symbols('E p_x p_y p_z', positive=True)
  1447. >>> repl.update({P(i0): [E, px, py, pz]})
  1448. The contravariant and covariant components are, respectively:
  1449. >>> P(i0).replace_with_arrays(repl, [i0])
  1450. [E, p_x, p_y, p_z]
  1451. >>> P(-i0).replace_with_arrays(repl, [-i0])
  1452. [E, -p_x, -p_y, -p_z]
  1453. The contraction of a 1-index tensor by itself:
  1454. >>> expr = P(i0)*P(-i0)
  1455. >>> expr.replace_with_arrays(repl, [])
  1456. E**2 - p_x**2 - p_y**2 - p_z**2
  1457. """
  1458. is_commutative = False
  1459. def __new__(cls, name, index_types, symmetry=None, comm=0):
  1460. if isinstance(name, str):
  1461. name_symbol = Symbol(name)
  1462. elif isinstance(name, Symbol):
  1463. name_symbol = name
  1464. else:
  1465. raise ValueError("invalid name")
  1466. if symmetry is None:
  1467. symmetry = TensorSymmetry.no_symmetry(len(index_types))
  1468. else:
  1469. assert symmetry.rank == len(index_types)
  1470. obj = Basic.__new__(cls, name_symbol, Tuple(*index_types), symmetry)
  1471. obj.comm = TensorManager.comm_symbols2i(comm)
  1472. return obj
  1473. @property
  1474. def name(self):
  1475. return self.args[0].name
  1476. @property
  1477. def index_types(self):
  1478. return list(self.args[1])
  1479. @property
  1480. def symmetry(self):
  1481. return self.args[2]
  1482. @property
  1483. def rank(self):
  1484. return len(self.index_types)
  1485. def __lt__(self, other):
  1486. return (self.name, self.index_types) < (other.name, other.index_types)
  1487. def commutes_with(self, other):
  1488. """
  1489. Returns ``0`` if ``self`` and ``other`` commute, ``1`` if they anticommute.
  1490. Returns ``None`` if ``self`` and ``other`` neither commute nor anticommute.
  1491. """
  1492. r = TensorManager.get_comm(self.comm, other.comm)
  1493. return r
  1494. def _print(self):
  1495. return '%s(%s)' %(self.name, ','.join([str(x) for x in self.index_types]))
  1496. def __call__(self, *indices, **kw_args):
  1497. """
  1498. Returns a tensor with indices.
  1499. Explanation
  1500. ===========
  1501. There is a special behavior in case of indices denoted by ``True``,
  1502. they are considered auto-matrix indices, their slots are automatically
  1503. filled, and confer to the tensor the behavior of a matrix or vector
  1504. upon multiplication with another tensor containing auto-matrix indices
  1505. of the same ``TensorIndexType``. This means indices get summed over the
  1506. same way as in matrix multiplication. For matrix behavior, define two
  1507. auto-matrix indices, for vector behavior define just one.
  1508. Indices can also be strings, in which case the attribute
  1509. ``index_types`` is used to convert them to proper ``TensorIndex``.
  1510. Examples
  1511. ========
  1512. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorSymmetry, TensorHead
  1513. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1514. >>> a, b = tensor_indices('a,b', Lorentz)
  1515. >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.no_symmetry(2))
  1516. >>> t = A(a, -b)
  1517. >>> t
  1518. A(a, -b)
  1519. """
  1520. updated_indices = []
  1521. for idx, typ in zip(indices, self.index_types):
  1522. if isinstance(idx, str):
  1523. idx = idx.strip().replace(" ", "")
  1524. if idx.startswith('-'):
  1525. updated_indices.append(TensorIndex(idx[1:], typ,
  1526. is_up=False))
  1527. else:
  1528. updated_indices.append(TensorIndex(idx, typ))
  1529. else:
  1530. updated_indices.append(idx)
  1531. updated_indices += indices[len(updated_indices):]
  1532. tensor = Tensor(self, updated_indices, **kw_args)
  1533. return tensor.doit()
  1534. # Everything below this line is deprecated
  1535. def __pow__(self, other):
  1536. deprecate_data()
  1537. with ignore_warnings(SymPyDeprecationWarning):
  1538. if self.data is None:
  1539. raise ValueError("No power on abstract tensors.")
  1540. from .array import tensorproduct, tensorcontraction
  1541. metrics = [_.data for _ in self.index_types]
  1542. marray = self.data
  1543. marraydim = marray.rank()
  1544. for metric in metrics:
  1545. marray = tensorproduct(marray, metric, marray)
  1546. marray = tensorcontraction(marray, (0, marraydim), (marraydim+1, marraydim+2))
  1547. return marray ** (other * S.Half)
  1548. @property
  1549. def data(self):
  1550. deprecate_data()
  1551. with ignore_warnings(SymPyDeprecationWarning):
  1552. return _tensor_data_substitution_dict[self]
  1553. @data.setter
  1554. def data(self, data):
  1555. deprecate_data()
  1556. with ignore_warnings(SymPyDeprecationWarning):
  1557. _tensor_data_substitution_dict[self] = data
  1558. @data.deleter
  1559. def data(self):
  1560. deprecate_data()
  1561. if self in _tensor_data_substitution_dict:
  1562. del _tensor_data_substitution_dict[self]
  1563. def __iter__(self):
  1564. deprecate_data()
  1565. with ignore_warnings(SymPyDeprecationWarning):
  1566. return self.data.__iter__()
  1567. def _components_data_full_destroy(self):
  1568. """
  1569. EXPERIMENTAL: do not rely on this API method.
  1570. Destroy components data associated to the ``TensorHead`` object, this
  1571. checks for attached components data, and destroys components data too.
  1572. """
  1573. # do not garbage collect Kronecker tensor (it should be done by
  1574. # ``TensorIndexType`` garbage collection)
  1575. deprecate_data()
  1576. if self.name == "KD":
  1577. return
  1578. # the data attached to a tensor must be deleted only by the TensorHead
  1579. # destructor. If the TensorHead is deleted, it means that there are no
  1580. # more instances of that tensor anywhere.
  1581. if self in _tensor_data_substitution_dict:
  1582. del _tensor_data_substitution_dict[self]
  1583. def tensor_heads(s, index_types, symmetry=None, comm=0):
  1584. """
  1585. Returns a sequence of TensorHeads from a string `s`
  1586. """
  1587. if isinstance(s, str):
  1588. names = [x.name for x in symbols(s, seq=True)]
  1589. else:
  1590. raise ValueError('expecting a string')
  1591. thlist = [TensorHead(name, index_types, symmetry, comm) for name in names]
  1592. if len(thlist) == 1:
  1593. return thlist[0]
  1594. return thlist
  1595. class TensExpr(Expr, ABC):
  1596. """
  1597. Abstract base class for tensor expressions
  1598. Notes
  1599. =====
  1600. A tensor expression is an expression formed by tensors;
  1601. currently the sums of tensors are distributed.
  1602. A ``TensExpr`` can be a ``TensAdd`` or a ``TensMul``.
  1603. ``TensMul`` objects are formed by products of component tensors,
  1604. and include a coefficient, which is a SymPy expression.
  1605. In the internal representation contracted indices are represented
  1606. by ``(ipos1, ipos2, icomp1, icomp2)``, where ``icomp1`` is the position
  1607. of the component tensor with contravariant index, ``ipos1`` is the
  1608. slot which the index occupies in that component tensor.
  1609. Contracted indices are therefore nameless in the internal representation.
  1610. """
  1611. _op_priority = 12.0
  1612. is_commutative = False
  1613. def __neg__(self):
  1614. return self*S.NegativeOne
  1615. def __abs__(self):
  1616. raise NotImplementedError
  1617. def __add__(self, other):
  1618. return TensAdd(self, other).doit()
  1619. def __radd__(self, other):
  1620. return TensAdd(other, self).doit()
  1621. def __sub__(self, other):
  1622. return TensAdd(self, -other).doit()
  1623. def __rsub__(self, other):
  1624. return TensAdd(other, -self).doit()
  1625. def __mul__(self, other):
  1626. """
  1627. Multiply two tensors using Einstein summation convention.
  1628. Explanation
  1629. ===========
  1630. If the two tensors have an index in common, one contravariant
  1631. and the other covariant, in their product the indices are summed
  1632. Examples
  1633. ========
  1634. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads
  1635. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1636. >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz)
  1637. >>> g = Lorentz.metric
  1638. >>> p, q = tensor_heads('p,q', [Lorentz])
  1639. >>> t1 = p(m0)
  1640. >>> t2 = q(-m0)
  1641. >>> t1*t2
  1642. p(L_0)*q(-L_0)
  1643. """
  1644. return TensMul(self, other).doit()
  1645. def __rmul__(self, other):
  1646. return TensMul(other, self).doit()
  1647. def __truediv__(self, other):
  1648. other = _sympify(other)
  1649. if isinstance(other, TensExpr):
  1650. raise ValueError('cannot divide by a tensor')
  1651. return TensMul(self, S.One/other).doit()
  1652. def __rtruediv__(self, other):
  1653. raise ValueError('cannot divide by a tensor')
  1654. def __pow__(self, other):
  1655. deprecate_data()
  1656. with ignore_warnings(SymPyDeprecationWarning):
  1657. if self.data is None:
  1658. raise ValueError("No power without ndarray data.")
  1659. from .array import tensorproduct, tensorcontraction
  1660. free = self.free
  1661. marray = self.data
  1662. mdim = marray.rank()
  1663. for metric in free:
  1664. marray = tensorcontraction(
  1665. tensorproduct(
  1666. marray,
  1667. metric[0].tensor_index_type.data,
  1668. marray),
  1669. (0, mdim), (mdim+1, mdim+2)
  1670. )
  1671. return marray ** (other * S.Half)
  1672. def __rpow__(self, other):
  1673. raise NotImplementedError
  1674. @property
  1675. @abstractmethod
  1676. def nocoeff(self):
  1677. raise NotImplementedError("abstract method")
  1678. @property
  1679. @abstractmethod
  1680. def coeff(self):
  1681. raise NotImplementedError("abstract method")
  1682. @abstractmethod
  1683. def get_indices(self):
  1684. raise NotImplementedError("abstract method")
  1685. @abstractmethod
  1686. def get_free_indices(self) -> list[TensorIndex]:
  1687. raise NotImplementedError("abstract method")
  1688. @abstractmethod
  1689. def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr:
  1690. raise NotImplementedError("abstract method")
  1691. def fun_eval(self, *index_tuples):
  1692. deprecate_fun_eval()
  1693. return self.substitute_indices(*index_tuples)
  1694. def get_matrix(self):
  1695. """
  1696. DEPRECATED: do not use.
  1697. Returns ndarray components data as a matrix, if components data are
  1698. available and ndarray dimension does not exceed 2.
  1699. """
  1700. from sympy.matrices.dense import Matrix
  1701. deprecate_data()
  1702. with ignore_warnings(SymPyDeprecationWarning):
  1703. if 0 < self.rank <= 2:
  1704. rows = self.data.shape[0]
  1705. columns = self.data.shape[1] if self.rank == 2 else 1
  1706. if self.rank == 2:
  1707. mat_list = [] * rows
  1708. for i in range(rows):
  1709. mat_list.append([])
  1710. for j in range(columns):
  1711. mat_list[i].append(self[i, j])
  1712. else:
  1713. mat_list = [None] * rows
  1714. for i in range(rows):
  1715. mat_list[i] = self[i]
  1716. return Matrix(mat_list)
  1717. else:
  1718. raise NotImplementedError(
  1719. "missing multidimensional reduction to matrix.")
  1720. @staticmethod
  1721. def _get_indices_permutation(indices1, indices2):
  1722. return [indices1.index(i) for i in indices2]
  1723. def expand(self, **hints):
  1724. return _expand(self, **hints).doit()
  1725. def _expand(self, **kwargs):
  1726. return self
  1727. def _get_free_indices_set(self):
  1728. indset = set()
  1729. for arg in self.args:
  1730. if isinstance(arg, TensExpr):
  1731. indset.update(arg._get_free_indices_set())
  1732. return indset
  1733. def _get_dummy_indices_set(self):
  1734. indset = set()
  1735. for arg in self.args:
  1736. if isinstance(arg, TensExpr):
  1737. indset.update(arg._get_dummy_indices_set())
  1738. return indset
  1739. def _get_indices_set(self):
  1740. indset = set()
  1741. for arg in self.args:
  1742. if isinstance(arg, TensExpr):
  1743. indset.update(arg._get_indices_set())
  1744. return indset
  1745. @property
  1746. def _iterate_dummy_indices(self):
  1747. dummy_set = self._get_dummy_indices_set()
  1748. def recursor(expr, pos):
  1749. if isinstance(expr, TensorIndex):
  1750. if expr in dummy_set:
  1751. yield (expr, pos)
  1752. elif isinstance(expr, (Tuple, TensExpr)):
  1753. for p, arg in enumerate(expr.args):
  1754. yield from recursor(arg, pos+(p,))
  1755. return recursor(self, ())
  1756. @property
  1757. def _iterate_free_indices(self):
  1758. free_set = self._get_free_indices_set()
  1759. def recursor(expr, pos):
  1760. if isinstance(expr, TensorIndex):
  1761. if expr in free_set:
  1762. yield (expr, pos)
  1763. elif isinstance(expr, (Tuple, TensExpr)):
  1764. for p, arg in enumerate(expr.args):
  1765. yield from recursor(arg, pos+(p,))
  1766. return recursor(self, ())
  1767. @property
  1768. def _iterate_indices(self):
  1769. def recursor(expr, pos):
  1770. if isinstance(expr, TensorIndex):
  1771. yield (expr, pos)
  1772. elif isinstance(expr, (Tuple, TensExpr)):
  1773. for p, arg in enumerate(expr.args):
  1774. yield from recursor(arg, pos+(p,))
  1775. return recursor(self, ())
  1776. @staticmethod
  1777. def _contract_and_permute_with_metric(metric, array, pos, dim):
  1778. # TODO: add possibility of metric after (spinors)
  1779. from .array import tensorcontraction, tensorproduct, permutedims
  1780. array = tensorcontraction(tensorproduct(metric, array), (1, 2+pos))
  1781. permu = list(range(dim))
  1782. permu[0], permu[pos] = permu[pos], permu[0]
  1783. return permutedims(array, permu)
  1784. @staticmethod
  1785. def _match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict):
  1786. from .array import permutedims
  1787. index_types1 = [i.tensor_index_type for i in free_ind1]
  1788. # Check if variance of indices needs to be fixed:
  1789. pos2up = []
  1790. pos2down = []
  1791. free2remaining = free_ind2[:]
  1792. for pos1, index1 in enumerate(free_ind1):
  1793. if index1 in free2remaining:
  1794. pos2 = free2remaining.index(index1)
  1795. free2remaining[pos2] = None
  1796. continue
  1797. if -index1 in free2remaining:
  1798. pos2 = free2remaining.index(-index1)
  1799. free2remaining[pos2] = None
  1800. free_ind2[pos2] = index1
  1801. if index1.is_up:
  1802. pos2up.append(pos2)
  1803. else:
  1804. pos2down.append(pos2)
  1805. else:
  1806. index2 = free2remaining[pos1]
  1807. if index2 is None:
  1808. raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2))
  1809. free2remaining[pos1] = None
  1810. free_ind2[pos1] = index1
  1811. if index1.is_up ^ index2.is_up:
  1812. if index1.is_up:
  1813. pos2up.append(pos1)
  1814. else:
  1815. pos2down.append(pos1)
  1816. if len(set(free_ind1) & set(free_ind2)) < len(free_ind1):
  1817. raise ValueError("incompatible indices: %s and %s" % (free_ind1, free_ind2))
  1818. # Raise indices:
  1819. for pos in pos2up:
  1820. index_type_pos = index_types1[pos]
  1821. if index_type_pos not in replacement_dict:
  1822. raise ValueError("No metric provided to lower index")
  1823. metric = replacement_dict[index_type_pos]
  1824. metric_inverse = _TensorDataLazyEvaluator.inverse_matrix(metric)
  1825. array = TensExpr._contract_and_permute_with_metric(metric_inverse, array, pos, len(free_ind1))
  1826. # Lower indices:
  1827. for pos in pos2down:
  1828. index_type_pos = index_types1[pos]
  1829. if index_type_pos not in replacement_dict:
  1830. raise ValueError("No metric provided to lower index")
  1831. metric = replacement_dict[index_type_pos]
  1832. array = TensExpr._contract_and_permute_with_metric(metric, array, pos, len(free_ind1))
  1833. if free_ind1:
  1834. permutation = TensExpr._get_indices_permutation(free_ind2, free_ind1)
  1835. array = permutedims(array, permutation)
  1836. if hasattr(array, "rank") and array.rank() == 0:
  1837. array = array[()]
  1838. return free_ind2, array
  1839. def replace_with_arrays(self, replacement_dict, indices=None):
  1840. """
  1841. Replace the tensorial expressions with arrays. The final array will
  1842. correspond to the N-dimensional array with indices arranged according
  1843. to ``indices``.
  1844. Parameters
  1845. ==========
  1846. replacement_dict
  1847. dictionary containing the replacement rules for tensors.
  1848. indices
  1849. the index order with respect to which the array is read. The
  1850. original index order will be used if no value is passed.
  1851. Examples
  1852. ========
  1853. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices
  1854. >>> from sympy.tensor.tensor import TensorHead
  1855. >>> from sympy import symbols, diag
  1856. >>> L = TensorIndexType("L")
  1857. >>> i, j = tensor_indices("i j", L)
  1858. >>> A = TensorHead("A", [L])
  1859. >>> A(i).replace_with_arrays({A(i): [1, 2]}, [i])
  1860. [1, 2]
  1861. Since 'indices' is optional, we can also call replace_with_arrays by
  1862. this way if no specific index order is needed:
  1863. >>> A(i).replace_with_arrays({A(i): [1, 2]})
  1864. [1, 2]
  1865. >>> expr = A(i)*A(j)
  1866. >>> expr.replace_with_arrays({A(i): [1, 2]})
  1867. [[1, 2], [2, 4]]
  1868. For contractions, specify the metric of the ``TensorIndexType``, which
  1869. in this case is ``L``, in its covariant form:
  1870. >>> expr = A(i)*A(-i)
  1871. >>> expr.replace_with_arrays({A(i): [1, 2], L: diag(1, -1)})
  1872. -3
  1873. Symmetrization of an array:
  1874. >>> H = TensorHead("H", [L, L])
  1875. >>> a, b, c, d = symbols("a b c d")
  1876. >>> expr = H(i, j)/2 + H(j, i)/2
  1877. >>> expr.replace_with_arrays({H(i, j): [[a, b], [c, d]]})
  1878. [[a, b/2 + c/2], [b/2 + c/2, d]]
  1879. Anti-symmetrization of an array:
  1880. >>> expr = H(i, j)/2 - H(j, i)/2
  1881. >>> repl = {H(i, j): [[a, b], [c, d]]}
  1882. >>> expr.replace_with_arrays(repl)
  1883. [[0, b/2 - c/2], [-b/2 + c/2, 0]]
  1884. The same expression can be read as the transpose by inverting ``i`` and
  1885. ``j``:
  1886. >>> expr.replace_with_arrays(repl, [j, i])
  1887. [[0, -b/2 + c/2], [b/2 - c/2, 0]]
  1888. """
  1889. from .array import Array
  1890. indices = indices or []
  1891. remap = {k.args[0] if k.is_up else -k.args[0]: k for k in self.get_free_indices()}
  1892. for i, index in enumerate(indices):
  1893. if isinstance(index, (Symbol, Mul)):
  1894. if index in remap:
  1895. indices[i] = remap[index]
  1896. else:
  1897. indices[i] = -remap[-index]
  1898. replacement_dict = {tensor: Array(array) for tensor, array in replacement_dict.items()}
  1899. # Check dimensions of replaced arrays:
  1900. for tensor, array in replacement_dict.items():
  1901. if isinstance(tensor, TensorIndexType):
  1902. expected_shape = [tensor.dim for i in range(2)]
  1903. else:
  1904. expected_shape = [index_type.dim for index_type in tensor.index_types]
  1905. if len(expected_shape) != array.rank() or (not all(dim1 == dim2 if
  1906. dim1.is_number else True for dim1, dim2 in zip(expected_shape,
  1907. array.shape))):
  1908. raise ValueError("shapes for tensor %s expected to be %s, "\
  1909. "replacement array shape is %s" % (tensor, expected_shape,
  1910. array.shape))
  1911. ret_indices, array = self._extract_data(replacement_dict)
  1912. last_indices, array = self._match_indices_with_other_tensor(array, indices, ret_indices, replacement_dict)
  1913. return array
  1914. def _check_add_Sum(self, expr, index_symbols):
  1915. from sympy.concrete.summations import Sum
  1916. indices = self.get_indices()
  1917. dum = self.dum
  1918. sum_indices = [ (index_symbols[i], 0,
  1919. indices[i].tensor_index_type.dim-1) for i, j in dum]
  1920. if sum_indices:
  1921. expr = Sum(expr, *sum_indices)
  1922. return expr
  1923. def _expand_partial_derivative(self):
  1924. # simply delegate the _expand_partial_derivative() to
  1925. # its arguments to expand a possibly found PartialDerivative
  1926. return self.func(*[
  1927. a._expand_partial_derivative()
  1928. if isinstance(a, TensExpr) else a
  1929. for a in self.args])
  1930. class TensAdd(TensExpr, AssocOp):
  1931. """
  1932. Sum of tensors.
  1933. Parameters
  1934. ==========
  1935. free_args : list of the free indices
  1936. Attributes
  1937. ==========
  1938. ``args`` : tuple of addends
  1939. ``rank`` : rank of the tensor
  1940. ``free_args`` : list of the free indices in sorted order
  1941. Examples
  1942. ========
  1943. >>> from sympy.tensor.tensor import TensorIndexType, tensor_heads, tensor_indices
  1944. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  1945. >>> a, b = tensor_indices('a,b', Lorentz)
  1946. >>> p, q = tensor_heads('p,q', [Lorentz])
  1947. >>> t = p(a) + q(a); t
  1948. p(a) + q(a)
  1949. Examples with components data added to the tensor expression:
  1950. >>> from sympy import symbols, diag
  1951. >>> x, y, z, t = symbols("x y z t")
  1952. >>> repl = {}
  1953. >>> repl[Lorentz] = diag(1, -1, -1, -1)
  1954. >>> repl[p(a)] = [1, 2, 3, 4]
  1955. >>> repl[q(a)] = [x, y, z, t]
  1956. The following are: 2**2 - 3**2 - 2**2 - 7**2 ==> -58
  1957. >>> expr = p(a) + q(a)
  1958. >>> expr.replace_with_arrays(repl, [a])
  1959. [x + 1, y + 2, z + 3, t + 4]
  1960. """
  1961. def __new__(cls, *args, **kw_args):
  1962. args = [_sympify(x) for x in args if x]
  1963. args = TensAdd._tensAdd_flatten(args)
  1964. args.sort(key=default_sort_key)
  1965. if not args:
  1966. return S.Zero
  1967. if len(args) == 1:
  1968. return args[0]
  1969. return Basic.__new__(cls, *args, **kw_args)
  1970. @property
  1971. def coeff(self):
  1972. return S.One
  1973. @property
  1974. def nocoeff(self):
  1975. return self
  1976. def get_free_indices(self) -> list[TensorIndex]:
  1977. return self.free_indices
  1978. def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr:
  1979. newargs = [arg._replace_indices(repl) if isinstance(arg, TensExpr) else arg for arg in self.args]
  1980. return self.func(*newargs)
  1981. @memoize_property
  1982. def rank(self):
  1983. if isinstance(self.args[0], TensExpr):
  1984. return self.args[0].rank
  1985. else:
  1986. return 0
  1987. @memoize_property
  1988. def free_args(self):
  1989. if isinstance(self.args[0], TensExpr):
  1990. return self.args[0].free_args
  1991. else:
  1992. return []
  1993. @memoize_property
  1994. def free_indices(self):
  1995. if isinstance(self.args[0], TensExpr):
  1996. return self.args[0].get_free_indices()
  1997. else:
  1998. return set()
  1999. def doit(self, **hints):
  2000. deep = hints.get('deep', True)
  2001. if deep:
  2002. args = [arg.doit(**hints) for arg in self.args]
  2003. else:
  2004. args = self.args
  2005. # if any of the args are zero (after doit), drop them. Otherwise, _tensAdd_check will complain about non-matching indices, even though the TensAdd is correctly formed.
  2006. args = [arg for arg in args if arg != S.Zero]
  2007. if len(args) == 0:
  2008. return S.Zero
  2009. elif len(args) == 1:
  2010. return args[0]
  2011. # now check that all addends have the same indices:
  2012. TensAdd._tensAdd_check(args)
  2013. # Collect terms appearing more than once, differing by their coefficients:
  2014. args = TensAdd._tensAdd_collect_terms(args)
  2015. # collect canonicalized terms
  2016. def sort_key(t):
  2017. if not isinstance(t, TensExpr):
  2018. return [], [], []
  2019. if hasattr(t, "_index_structure") and hasattr(t, "components"):
  2020. x = get_index_structure(t)
  2021. return t.components, x.free, x.dum
  2022. return [], [], []
  2023. args.sort(key=sort_key)
  2024. if not args:
  2025. return S.Zero
  2026. # it there is only a component tensor return it
  2027. if len(args) == 1:
  2028. return args[0]
  2029. obj = self.func(*args)
  2030. return obj
  2031. @staticmethod
  2032. def _tensAdd_flatten(args):
  2033. # flatten TensAdd, coerce terms which are not tensors to tensors
  2034. a = []
  2035. for x in args:
  2036. if isinstance(x, (Add, TensAdd)):
  2037. a.extend(list(x.args))
  2038. else:
  2039. a.append(x)
  2040. args = [x for x in a if x.coeff]
  2041. return args
  2042. @staticmethod
  2043. def _tensAdd_check(args):
  2044. # check that all addends have the same free indices
  2045. def get_indices_set(x: Expr) -> set[TensorIndex]:
  2046. if isinstance(x, TensExpr):
  2047. return set(x.get_free_indices())
  2048. return set()
  2049. indices0 = get_indices_set(args[0])
  2050. list_indices = [get_indices_set(arg) for arg in args[1:]]
  2051. if not all(x == indices0 for x in list_indices):
  2052. raise ValueError('all tensors must have the same indices')
  2053. @staticmethod
  2054. def _tensAdd_collect_terms(args):
  2055. # collect TensMul terms differing at most by their coefficient
  2056. terms_dict = defaultdict(list)
  2057. scalars = S.Zero
  2058. if isinstance(args[0], TensExpr):
  2059. free_indices = set(args[0].get_free_indices())
  2060. else:
  2061. free_indices = set()
  2062. for arg in args:
  2063. if not isinstance(arg, TensExpr):
  2064. if free_indices != set():
  2065. raise ValueError("wrong valence")
  2066. scalars += arg
  2067. continue
  2068. if free_indices != set(arg.get_free_indices()):
  2069. raise ValueError("wrong valence")
  2070. # TODO: what is the part which is not a coeff?
  2071. # needs an implementation similar to .as_coeff_Mul()
  2072. terms_dict[arg.nocoeff].append(arg.coeff)
  2073. new_args = [TensMul(Add(*coeff), t).doit() for t, coeff in terms_dict.items() if Add(*coeff) != 0]
  2074. if isinstance(scalars, Add):
  2075. new_args = list(scalars.args) + new_args
  2076. elif scalars != 0:
  2077. new_args = [scalars] + new_args
  2078. return new_args
  2079. def get_indices(self):
  2080. indices = []
  2081. for arg in self.args:
  2082. indices.extend([i for i in get_indices(arg) if i not in indices])
  2083. return indices
  2084. def _expand(self, **hints):
  2085. return TensAdd(*[_expand(i, **hints) for i in self.args])
  2086. def __call__(self, *indices):
  2087. deprecate_call()
  2088. free_args = self.free_args
  2089. indices = list(indices)
  2090. if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]:
  2091. raise ValueError('incompatible types')
  2092. if indices == free_args:
  2093. return self
  2094. index_tuples = list(zip(free_args, indices))
  2095. a = [x.func(*x.substitute_indices(*index_tuples).args) for x in self.args]
  2096. res = TensAdd(*a).doit()
  2097. return res
  2098. def canon_bp(self):
  2099. """
  2100. Canonicalize using the Butler-Portugal algorithm for canonicalization
  2101. under monoterm symmetries.
  2102. """
  2103. expr = self.expand()
  2104. args = [canon_bp(x) for x in expr.args]
  2105. res = TensAdd(*args).doit()
  2106. return res
  2107. def equals(self, other):
  2108. other = _sympify(other)
  2109. if isinstance(other, TensMul) and other.coeff == 0:
  2110. return all(x.coeff == 0 for x in self.args)
  2111. if isinstance(other, TensExpr):
  2112. if self.rank != other.rank:
  2113. return False
  2114. if isinstance(other, TensAdd):
  2115. if set(self.args) != set(other.args):
  2116. return False
  2117. else:
  2118. return True
  2119. t = self - other
  2120. if not isinstance(t, TensExpr):
  2121. return t == 0
  2122. else:
  2123. if isinstance(t, TensMul):
  2124. return t.coeff == 0
  2125. else:
  2126. return all(x.coeff == 0 for x in t.args)
  2127. def __getitem__(self, item):
  2128. deprecate_data()
  2129. with ignore_warnings(SymPyDeprecationWarning):
  2130. return self.data[item]
  2131. def contract_delta(self, delta):
  2132. args = [x.contract_delta(delta) for x in self.args]
  2133. t = TensAdd(*args).doit()
  2134. return canon_bp(t)
  2135. def contract_metric(self, g):
  2136. """
  2137. Raise or lower indices with the metric ``g``.
  2138. Parameters
  2139. ==========
  2140. g : metric
  2141. contract_all : if True, eliminate all ``g`` which are contracted
  2142. Notes
  2143. =====
  2144. see the ``TensorIndexType`` docstring for the contraction conventions
  2145. """
  2146. args = [contract_metric(x, g) for x in self.args]
  2147. t = TensAdd(*args).doit()
  2148. return canon_bp(t)
  2149. def substitute_indices(self, *index_tuples):
  2150. new_args = []
  2151. for arg in self.args:
  2152. if isinstance(arg, TensExpr):
  2153. arg = arg.substitute_indices(*index_tuples)
  2154. new_args.append(arg)
  2155. return TensAdd(*new_args).doit()
  2156. def _print(self):
  2157. a = []
  2158. args = self.args
  2159. for x in args:
  2160. a.append(str(x))
  2161. s = ' + '.join(a)
  2162. s = s.replace('+ -', '- ')
  2163. return s
  2164. def _extract_data(self, replacement_dict):
  2165. from sympy.tensor.array import Array, permutedims
  2166. args_indices, arrays = zip(*[
  2167. arg._extract_data(replacement_dict) if
  2168. isinstance(arg, TensExpr) else ([], arg) for arg in self.args
  2169. ])
  2170. arrays = [Array(i) for i in arrays]
  2171. ref_indices = args_indices[0]
  2172. for i in range(1, len(args_indices)):
  2173. indices = args_indices[i]
  2174. array = arrays[i]
  2175. permutation = TensMul._get_indices_permutation(indices, ref_indices)
  2176. arrays[i] = permutedims(array, permutation)
  2177. return ref_indices, sum(arrays, Array.zeros(*array.shape))
  2178. @property
  2179. def data(self):
  2180. deprecate_data()
  2181. with ignore_warnings(SymPyDeprecationWarning):
  2182. return _tensor_data_substitution_dict[self.expand()]
  2183. @data.setter
  2184. def data(self, data):
  2185. deprecate_data()
  2186. with ignore_warnings(SymPyDeprecationWarning):
  2187. _tensor_data_substitution_dict[self] = data
  2188. @data.deleter
  2189. def data(self):
  2190. deprecate_data()
  2191. with ignore_warnings(SymPyDeprecationWarning):
  2192. if self in _tensor_data_substitution_dict:
  2193. del _tensor_data_substitution_dict[self]
  2194. def __iter__(self):
  2195. deprecate_data()
  2196. if not self.data:
  2197. raise ValueError("No iteration on abstract tensors")
  2198. return self.data.flatten().__iter__()
  2199. def _eval_rewrite_as_Indexed(self, *args):
  2200. return Add.fromiter(args)
  2201. def _eval_partial_derivative(self, s):
  2202. # Evaluation like Add
  2203. list_addends = []
  2204. for a in self.args:
  2205. if isinstance(a, TensExpr):
  2206. list_addends.append(a._eval_partial_derivative(s))
  2207. # do not call diff if s is no symbol
  2208. elif s._diff_wrt:
  2209. list_addends.append(a._eval_derivative(s))
  2210. return self.func(*list_addends)
  2211. class Tensor(TensExpr):
  2212. """
  2213. Base tensor class, i.e. this represents a tensor, the single unit to be
  2214. put into an expression.
  2215. Explanation
  2216. ===========
  2217. This object is usually created from a ``TensorHead``, by attaching indices
  2218. to it. Indices preceded by a minus sign are considered contravariant,
  2219. otherwise covariant.
  2220. Examples
  2221. ========
  2222. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead
  2223. >>> Lorentz = TensorIndexType("Lorentz", dummy_name="L")
  2224. >>> mu, nu = tensor_indices('mu nu', Lorentz)
  2225. >>> A = TensorHead("A", [Lorentz, Lorentz])
  2226. >>> A(mu, -nu)
  2227. A(mu, -nu)
  2228. >>> A(mu, -mu)
  2229. A(L_0, -L_0)
  2230. It is also possible to use symbols instead of inidices (appropriate indices
  2231. are then generated automatically).
  2232. >>> from sympy import Symbol
  2233. >>> x = Symbol('x')
  2234. >>> A(x, mu)
  2235. A(x, mu)
  2236. >>> A(x, -x)
  2237. A(L_0, -L_0)
  2238. """
  2239. is_commutative = False
  2240. _index_structure = None # type: _IndexStructure
  2241. args: tuple[TensorHead, Tuple]
  2242. def __new__(cls, tensor_head, indices, *, is_canon_bp=False, **kw_args):
  2243. indices = cls._parse_indices(tensor_head, indices)
  2244. obj = Basic.__new__(cls, tensor_head, Tuple(*indices), **kw_args)
  2245. obj._index_structure = _IndexStructure.from_indices(*indices)
  2246. obj._free = obj._index_structure.free[:]
  2247. obj._dum = obj._index_structure.dum[:]
  2248. obj._ext_rank = obj._index_structure._ext_rank
  2249. obj._coeff = S.One
  2250. obj._nocoeff = obj
  2251. obj._component = tensor_head
  2252. obj._components = [tensor_head]
  2253. if tensor_head.rank != len(indices):
  2254. raise ValueError("wrong number of indices")
  2255. obj.is_canon_bp = is_canon_bp
  2256. obj._index_map = Tensor._build_index_map(indices, obj._index_structure)
  2257. return obj
  2258. @property
  2259. def free(self):
  2260. return self._free
  2261. @property
  2262. def dum(self):
  2263. return self._dum
  2264. @property
  2265. def ext_rank(self):
  2266. return self._ext_rank
  2267. @property
  2268. def coeff(self):
  2269. return self._coeff
  2270. @property
  2271. def nocoeff(self):
  2272. return self._nocoeff
  2273. @property
  2274. def component(self):
  2275. return self._component
  2276. @property
  2277. def components(self):
  2278. return self._components
  2279. @property
  2280. def head(self):
  2281. return self.args[0]
  2282. @property
  2283. def indices(self):
  2284. return self.args[1]
  2285. @property
  2286. def free_indices(self):
  2287. return set(self._index_structure.get_free_indices())
  2288. @property
  2289. def index_types(self):
  2290. return self.head.index_types
  2291. @property
  2292. def rank(self):
  2293. return len(self.free_indices)
  2294. @staticmethod
  2295. def _build_index_map(indices, index_structure):
  2296. index_map = {}
  2297. for idx in indices:
  2298. index_map[idx] = (indices.index(idx),)
  2299. return index_map
  2300. def doit(self, **hints):
  2301. args, indices, free, dum = TensMul._tensMul_contract_indices([self])
  2302. return args[0]
  2303. @staticmethod
  2304. def _parse_indices(tensor_head, indices):
  2305. if not isinstance(indices, (tuple, list, Tuple)):
  2306. raise TypeError("indices should be an array, got %s" % type(indices))
  2307. indices = list(indices)
  2308. for i, index in enumerate(indices):
  2309. if isinstance(index, Symbol):
  2310. indices[i] = TensorIndex(index, tensor_head.index_types[i], True)
  2311. elif isinstance(index, Mul):
  2312. c, e = index.as_coeff_Mul()
  2313. if c == -1 and isinstance(e, Symbol):
  2314. indices[i] = TensorIndex(e, tensor_head.index_types[i], False)
  2315. else:
  2316. raise ValueError("index not understood: %s" % index)
  2317. elif not isinstance(index, TensorIndex):
  2318. raise TypeError("wrong type for index: %s is %s" % (index, type(index)))
  2319. return indices
  2320. def _set_new_index_structure(self, im, is_canon_bp=False):
  2321. indices = im.get_indices()
  2322. return self._set_indices(*indices, is_canon_bp=is_canon_bp)
  2323. def _set_indices(self, *indices, is_canon_bp=False, **kw_args):
  2324. if len(indices) != self.ext_rank:
  2325. raise ValueError("indices length mismatch")
  2326. return self.func(self.args[0], indices, is_canon_bp=is_canon_bp).doit()
  2327. def _get_free_indices_set(self):
  2328. return {i[0] for i in self._index_structure.free}
  2329. def _get_dummy_indices_set(self):
  2330. dummy_pos = set(itertools.chain(*self._index_structure.dum))
  2331. return {idx for i, idx in enumerate(self.args[1]) if i in dummy_pos}
  2332. def _get_indices_set(self):
  2333. return set(self.args[1].args)
  2334. @property
  2335. def free_in_args(self):
  2336. return [(ind, pos, 0) for ind, pos in self.free]
  2337. @property
  2338. def dum_in_args(self):
  2339. return [(p1, p2, 0, 0) for p1, p2 in self.dum]
  2340. @property
  2341. def free_args(self):
  2342. return sorted([x[0] for x in self.free])
  2343. def commutes_with(self, other):
  2344. """
  2345. :param other:
  2346. :return:
  2347. 0 commute
  2348. 1 anticommute
  2349. None neither commute nor anticommute
  2350. """
  2351. if not isinstance(other, TensExpr):
  2352. return 0
  2353. elif isinstance(other, Tensor):
  2354. return self.component.commutes_with(other.component)
  2355. return NotImplementedError
  2356. def perm2tensor(self, g, is_canon_bp=False):
  2357. """
  2358. Returns the tensor corresponding to the permutation ``g``.
  2359. For further details, see the method in ``TIDS`` with the same name.
  2360. """
  2361. return perm2tensor(self, g, is_canon_bp)
  2362. def canon_bp(self):
  2363. if self.is_canon_bp:
  2364. return self
  2365. expr = self.expand()
  2366. g, dummies, msym = expr._index_structure.indices_canon_args()
  2367. v = components_canon_args([expr.component])
  2368. can = canonicalize(g, dummies, msym, *v)
  2369. if can == 0:
  2370. return S.Zero
  2371. tensor = self.perm2tensor(can, True)
  2372. return tensor
  2373. def split(self):
  2374. return [self]
  2375. def _expand(self, **kwargs):
  2376. return self
  2377. def sorted_components(self):
  2378. return self
  2379. def get_indices(self) -> list[TensorIndex]:
  2380. """
  2381. Get a list of indices, corresponding to those of the tensor.
  2382. """
  2383. return list(self.args[1])
  2384. def get_free_indices(self) -> list[TensorIndex]:
  2385. """
  2386. Get a list of free indices, corresponding to those of the tensor.
  2387. """
  2388. return self._index_structure.get_free_indices()
  2389. def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr:
  2390. # TODO: this could be optimized by only swapping the indices
  2391. # instead of visiting the whole expression tree:
  2392. return self.xreplace(repl)
  2393. def as_base_exp(self):
  2394. return self, S.One
  2395. def substitute_indices(self, *index_tuples):
  2396. """
  2397. Return a tensor with free indices substituted according to ``index_tuples``.
  2398. ``index_types`` list of tuples ``(old_index, new_index)``.
  2399. Examples
  2400. ========
  2401. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry
  2402. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  2403. >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz)
  2404. >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2))
  2405. >>> t = A(i, k)*B(-k, -j); t
  2406. A(i, L_0)*B(-L_0, -j)
  2407. >>> t.substitute_indices((i, k),(-j, l))
  2408. A(k, L_0)*B(-L_0, l)
  2409. """
  2410. indices = []
  2411. for index in self.indices:
  2412. for ind_old, ind_new in index_tuples:
  2413. if (index.name == ind_old.name and index.tensor_index_type ==
  2414. ind_old.tensor_index_type):
  2415. if index.is_up == ind_old.is_up:
  2416. indices.append(ind_new)
  2417. else:
  2418. indices.append(-ind_new)
  2419. break
  2420. else:
  2421. indices.append(index)
  2422. return self.head(*indices)
  2423. def __call__(self, *indices):
  2424. deprecate_call()
  2425. free_args = self.free_args
  2426. indices = list(indices)
  2427. if [x.tensor_index_type for x in indices] != [x.tensor_index_type for x in free_args]:
  2428. raise ValueError('incompatible types')
  2429. if indices == free_args:
  2430. return self
  2431. t = self.substitute_indices(*list(zip(free_args, indices)))
  2432. # object is rebuilt in order to make sure that all contracted indices
  2433. # get recognized as dummies, but only if there are contracted indices.
  2434. if len({i if i.is_up else -i for i in indices}) != len(indices):
  2435. return t.func(*t.args)
  2436. return t
  2437. # TODO: put this into TensExpr?
  2438. def __iter__(self):
  2439. deprecate_data()
  2440. with ignore_warnings(SymPyDeprecationWarning):
  2441. return self.data.__iter__()
  2442. # TODO: put this into TensExpr?
  2443. def __getitem__(self, item):
  2444. deprecate_data()
  2445. with ignore_warnings(SymPyDeprecationWarning):
  2446. return self.data[item]
  2447. def _extract_data(self, replacement_dict):
  2448. from .array import Array
  2449. for k, v in replacement_dict.items():
  2450. if isinstance(k, Tensor) and k.args[0] == self.args[0]:
  2451. other = k
  2452. array = v
  2453. break
  2454. else:
  2455. raise ValueError("%s not found in %s" % (self, replacement_dict))
  2456. # TODO: inefficient, this should be done at root level only:
  2457. replacement_dict = {k: Array(v) for k, v in replacement_dict.items()}
  2458. array = Array(array)
  2459. dum1 = self.dum
  2460. dum2 = other.dum
  2461. if len(dum2) > 0:
  2462. for pair in dum2:
  2463. # allow `dum2` if the contained values are also in `dum1`.
  2464. if pair not in dum1:
  2465. raise NotImplementedError("%s with contractions is not implemented" % other)
  2466. # Remove elements in `dum2` from `dum1`:
  2467. dum1 = [pair for pair in dum1 if pair not in dum2]
  2468. if len(dum1) > 0:
  2469. indices1 = self.get_indices()
  2470. indices2 = other.get_indices()
  2471. repl = {}
  2472. for p1, p2 in dum1:
  2473. repl[indices2[p2]] = -indices2[p1]
  2474. for pos in (p1, p2):
  2475. if indices1[pos].is_up ^ indices2[pos].is_up:
  2476. metric = replacement_dict[indices1[pos].tensor_index_type]
  2477. if indices1[pos].is_up:
  2478. metric = _TensorDataLazyEvaluator.inverse_matrix(metric)
  2479. array = self._contract_and_permute_with_metric(metric, array, pos, len(indices2))
  2480. other = other.xreplace(repl).doit()
  2481. array = _TensorDataLazyEvaluator.data_contract_dum([array], dum1, len(indices2))
  2482. free_ind1 = self.get_free_indices()
  2483. free_ind2 = other.get_free_indices()
  2484. return self._match_indices_with_other_tensor(array, free_ind1, free_ind2, replacement_dict)
  2485. @property
  2486. def data(self):
  2487. deprecate_data()
  2488. with ignore_warnings(SymPyDeprecationWarning):
  2489. return _tensor_data_substitution_dict[self]
  2490. @data.setter
  2491. def data(self, data):
  2492. deprecate_data()
  2493. # TODO: check data compatibility with properties of tensor.
  2494. with ignore_warnings(SymPyDeprecationWarning):
  2495. _tensor_data_substitution_dict[self] = data
  2496. @data.deleter
  2497. def data(self):
  2498. deprecate_data()
  2499. with ignore_warnings(SymPyDeprecationWarning):
  2500. if self in _tensor_data_substitution_dict:
  2501. del _tensor_data_substitution_dict[self]
  2502. if self.metric in _tensor_data_substitution_dict:
  2503. del _tensor_data_substitution_dict[self.metric]
  2504. def _print(self):
  2505. indices = [str(ind) for ind in self.indices]
  2506. component = self.component
  2507. if component.rank > 0:
  2508. return ('%s(%s)' % (component.name, ', '.join(indices)))
  2509. else:
  2510. return ('%s' % component.name)
  2511. def equals(self, other):
  2512. if other == 0:
  2513. return self.coeff == 0
  2514. other = _sympify(other)
  2515. if not isinstance(other, TensExpr):
  2516. assert not self.components
  2517. return S.One == other
  2518. def _get_compar_comp(self):
  2519. t = self.canon_bp()
  2520. r = (t.coeff, tuple(t.components), \
  2521. tuple(sorted(t.free)), tuple(sorted(t.dum)))
  2522. return r
  2523. return _get_compar_comp(self) == _get_compar_comp(other)
  2524. def contract_metric(self, g):
  2525. # if metric is not the same, ignore this step:
  2526. if self.component != g:
  2527. return self
  2528. # in case there are free components, do not perform anything:
  2529. if len(self.free) != 0:
  2530. return self
  2531. #antisym = g.index_types[0].metric_antisym
  2532. if g.symmetry == TensorSymmetry.fully_symmetric(-2):
  2533. antisym = 1
  2534. elif g.symmetry == TensorSymmetry.fully_symmetric(2):
  2535. antisym = 0
  2536. elif g.symmetry == TensorSymmetry.no_symmetry(2):
  2537. antisym = None
  2538. else:
  2539. raise NotImplementedError
  2540. sign = S.One
  2541. typ = g.index_types[0]
  2542. if not antisym:
  2543. # g(i, -i)
  2544. sign = sign*typ.dim
  2545. else:
  2546. # g(i, -i)
  2547. sign = sign*typ.dim
  2548. dp0, dp1 = self.dum[0]
  2549. if dp0 < dp1:
  2550. # g(i, -i) = -D with antisymmetric metric
  2551. sign = -sign
  2552. return sign
  2553. def contract_delta(self, metric):
  2554. return self.contract_metric(metric)
  2555. def _eval_rewrite_as_Indexed(self, tens, indices):
  2556. from sympy.tensor.indexed import Indexed
  2557. # TODO: replace .args[0] with .name:
  2558. index_symbols = [i.args[0] for i in self.get_indices()]
  2559. expr = Indexed(tens.args[0], *index_symbols)
  2560. return self._check_add_Sum(expr, index_symbols)
  2561. def _eval_partial_derivative(self, s): # type: (Tensor) -> Expr
  2562. if not isinstance(s, Tensor):
  2563. return S.Zero
  2564. else:
  2565. # @a_i/@a_k = delta_i^k
  2566. # @a_i/@a^k = g_ij delta^j_k
  2567. # @a^i/@a^k = delta^i_k
  2568. # @a^i/@a_k = g^ij delta_j^k
  2569. # TODO: if there is no metric present, the derivative should be zero?
  2570. if self.head != s.head:
  2571. return S.Zero
  2572. # if heads are the same, provide delta and/or metric products
  2573. # for every free index pair in the appropriate tensor
  2574. # assumed that the free indices are in proper order
  2575. # A contravariante index in the derivative becomes covariant
  2576. # after performing the derivative and vice versa
  2577. kronecker_delta_list = [1]
  2578. # not guarantee a correct index order
  2579. for (count, (iself, iother)) in enumerate(zip(self.get_free_indices(), s.get_free_indices())):
  2580. if iself.tensor_index_type != iother.tensor_index_type:
  2581. raise ValueError("index types not compatible")
  2582. else:
  2583. tensor_index_type = iself.tensor_index_type
  2584. tensor_metric = tensor_index_type.metric
  2585. dummy = TensorIndex("d_" + str(count), tensor_index_type,
  2586. is_up=iself.is_up)
  2587. if iself.is_up == iother.is_up:
  2588. kroneckerdelta = tensor_index_type.delta(iself, -iother)
  2589. else:
  2590. kroneckerdelta = (
  2591. TensMul(tensor_metric(iself, dummy),
  2592. tensor_index_type.delta(-dummy, -iother))
  2593. )
  2594. kronecker_delta_list.append(kroneckerdelta)
  2595. return TensMul.fromiter(kronecker_delta_list).doit()
  2596. # doit necessary to rename dummy indices accordingly
  2597. class TensMul(TensExpr, AssocOp):
  2598. """
  2599. Product of tensors.
  2600. Parameters
  2601. ==========
  2602. coeff : SymPy coefficient of the tensor
  2603. args
  2604. Attributes
  2605. ==========
  2606. ``components`` : list of ``TensorHead`` of the component tensors
  2607. ``types`` : list of nonrepeated ``TensorIndexType``
  2608. ``free`` : list of ``(ind, ipos, icomp)``, see Notes
  2609. ``dum`` : list of ``(ipos1, ipos2, icomp1, icomp2)``, see Notes
  2610. ``ext_rank`` : rank of the tensor counting the dummy indices
  2611. ``rank`` : rank of the tensor
  2612. ``coeff`` : SymPy coefficient of the tensor
  2613. ``free_args`` : list of the free indices in sorted order
  2614. ``is_canon_bp`` : ``True`` if the tensor is in canonical form
  2615. Notes
  2616. =====
  2617. ``args[0]`` list of ``TensorHead`` of the component tensors.
  2618. ``args[1]`` list of ``(ind, ipos, icomp)``
  2619. where ``ind`` is a free index, ``ipos`` is the slot position
  2620. of ``ind`` in the ``icomp``-th component tensor.
  2621. ``args[2]`` list of tuples representing dummy indices.
  2622. ``(ipos1, ipos2, icomp1, icomp2)`` indicates that the contravariant
  2623. dummy index is the ``ipos1``-th slot position in the ``icomp1``-th
  2624. component tensor; the corresponding covariant index is
  2625. in the ``ipos2`` slot position in the ``icomp2``-th component tensor.
  2626. """
  2627. identity = S.One
  2628. _index_structure = None # type: _IndexStructure
  2629. def __new__(cls, *args, **kw_args):
  2630. is_canon_bp = kw_args.get('is_canon_bp', False)
  2631. args = list(map(_sympify, args))
  2632. """
  2633. If the internal dummy indices in one arg conflict with the free indices
  2634. of the remaining args, we need to rename those internal dummy indices.
  2635. """
  2636. free = [get_free_indices(arg) for arg in args]
  2637. free = set(itertools.chain(*free)) #flatten free
  2638. newargs = []
  2639. for arg in args:
  2640. dum_this = set(get_dummy_indices(arg))
  2641. dum_other = [get_dummy_indices(a) for a in newargs]
  2642. dum_other = set(itertools.chain(*dum_other)) #flatten dum_other
  2643. free_this = set(get_free_indices(arg))
  2644. if len(dum_this.intersection(free)) > 0:
  2645. exclude = free_this.union(free, dum_other)
  2646. newarg = TensMul._dedupe_indices(arg, exclude)
  2647. else:
  2648. newarg = arg
  2649. newargs.append(newarg)
  2650. args = newargs
  2651. # Flatten:
  2652. args = [i for arg in args for i in (arg.args if isinstance(arg, (TensMul, Mul)) else [arg])]
  2653. args, indices, free, dum = TensMul._tensMul_contract_indices(args, replace_indices=False)
  2654. # Data for indices:
  2655. index_types = [i.tensor_index_type for i in indices]
  2656. index_structure = _IndexStructure(free, dum, index_types, indices, canon_bp=is_canon_bp)
  2657. obj = TensExpr.__new__(cls, *args)
  2658. obj._indices = indices
  2659. obj._index_types = index_types[:]
  2660. obj._index_structure = index_structure
  2661. obj._free = index_structure.free[:]
  2662. obj._dum = index_structure.dum[:]
  2663. obj._free_indices = {x[0] for x in obj.free}
  2664. obj._rank = len(obj.free)
  2665. obj._ext_rank = len(obj._index_structure.free) + 2*len(obj._index_structure.dum)
  2666. obj._coeff = S.One
  2667. obj._is_canon_bp = is_canon_bp
  2668. return obj
  2669. index_types = property(lambda self: self._index_types)
  2670. free = property(lambda self: self._free)
  2671. dum = property(lambda self: self._dum)
  2672. free_indices = property(lambda self: self._free_indices)
  2673. rank = property(lambda self: self._rank)
  2674. ext_rank = property(lambda self: self._ext_rank)
  @staticmethod
  def _indices_to_free_dum(args_indices):
      """
      Classify the indices of a product into free and contracted ones.

      Explanation
      ===========

      ``args_indices`` holds one sequence of indices per argument of the
      product. Two kinds of positions are tracked:

      * the argument position: which argument an index belongs to;
      * the overall position: the place of the index counting across all
        arguments.

      For ``A(i, j)*B(k, m, n)*C(p)`` the index ``n`` has argument position
      1 (it sits in ``B``) and overall position 4 (fifth index overall).

      Returns
      =======

      A tuple ``(indices, free, free_names, dummy_data)`` where ``indices``
      is the flat list of indices, ``free`` is a list of
      ``(index, overall position)`` pairs, ``free_names`` lists the names of
      the free indices and ``dummy_data`` is a list of 5-tuples describing
      each contraction, sorted by their fourth entry.

      Raises
      ======

      TypeError
          If an entry is not a ``TensorIndex``.
      ValueError
          If the same index appears twice with the same valence.
      """
      pending_arg = {}
      pending_position = {}
      contractions = []
      flat_indices = []

      # Running position of the current index within the whole expression:
      position = 0
      for arg_no, arg_indices in enumerate(args_indices):
          for index in arg_indices:
              if not isinstance(index, TensorIndex):
                  raise TypeError("expected TensorIndex")
              partner = -index
              if partner in pending_arg:
                  # The opposite-valence index was seen before: contraction.
                  partner_arg = pending_arg.pop(partner)
                  partner_position = pending_position.pop(partner)
                  if index.is_up:
                      entry = (index, arg_no, partner_arg, position, partner_position)
                  else:
                      entry = (partner, partner_arg, arg_no, partner_position, position)
                  contractions.append(entry)
              elif index in pending_arg:
                  raise ValueError("Repeated index: %s" % index)
              else:
                  pending_arg[index] = arg_no
                  pending_position[index] = position
              flat_indices.append(index)
              position += 1

      free = list(pending_position.items())
      free_names = [index.name for index in pending_position]
      contractions.sort(key=lambda entry: entry[3])
      return flat_indices, free, free_names, contractions
  @staticmethod
  def _dummy_data_to_dum(dummy_data):
      """
      Reduce 5-tuples of dummy data to pairs of overall index positions.

      Only the last two entries of every 5-tuple are kept.
      """
      pairs = []
      for _index, _arg_a, _arg_b, position_a, position_b in dummy_data:
          pairs.append((position_a, position_b))
      return pairs
  @staticmethod
  def _tensMul_contract_indices(args, replace_indices=True):
      """
      Detect contracted indices among ``args`` and optionally rename them.

      Explanation
      ===========

      When ``replace_indices`` is true, every contracted pair receives a
      fresh dummy index whose name is ``<dummy_name>_<n>`` for its index
      type, skipping names already used by free indices, and the tensor
      arguments are rebuilt with the renamed indices.

      Returns
      =======

      A tuple ``(args, indices, free, dum)``.
      """
      per_arg_replacements = [{} for _ in args]

      args_indices = [get_indices(arg) for arg in args]
      indices, free, free_names, dummy_data = TensMul._indices_to_free_dum(args_indices)

      # Next number to try for each index type when naming dummies:
      next_number = defaultdict(int)

      if replace_indices:
          for old_index, arg_cov, arg_contra, position_cov, position_contra in dummy_data:
              index_type = old_index.tensor_index_type
              while True:
                  candidate = index_type.dummy_name + '_' + str(next_number[index_type])
                  next_number[index_type] += 1
                  if candidate not in free_names:
                      break
              dummy = TensorIndex(candidate, index_type, True)
              per_arg_replacements[arg_cov][old_index] = dummy
              per_arg_replacements[arg_contra][-old_index] = -dummy
              indices[position_cov] = dummy
              indices[position_contra] = -dummy
          renamed_args = []
          for arg, repl in zip(args, per_arg_replacements):
              if isinstance(arg, TensExpr):
                  renamed_args.append(arg._replace_indices(repl))
              else:
                  renamed_args.append(arg)
          args = renamed_args

      dum = TensMul._dummy_data_to_dum(dummy_data)
      return args, indices, free, dum
  @staticmethod
  def _get_components_from_args(args):
      """
      Collect the components of the tensor factors found in ``args``.

      Factors that are not ``TensExpr`` instances, and ``TensAdd`` factors,
      contribute nothing.
      """
      collected = []
      for factor in args:
          if isinstance(factor, TensExpr) and not isinstance(factor, TensAdd):
              collected.extend(factor.components)
      return collected
  @staticmethod
  def _rebuild_tensors_list(args, index_structure):
      """
      Replace, in place, every tensor factor of ``args`` by a ``Tensor``
      built from its component and its slice of the indices of
      ``index_structure``.

      Non-tensor factors are left untouched. Nothing is returned.
      """
      all_indices = index_structure.get_indices()
      start = 0
      for slot, factor in enumerate(args):
          if isinstance(factor, TensExpr):
              stop = start + factor.ext_rank
              args[slot] = Tensor(factor.component, all_indices[start:stop])
              start = stop
  def doit(self, **hints):
      """
      Evaluate the product.

      Explanation
      ===========

      With the hint ``deep=True`` (the default) every argument is evaluated
      first and clashes between their dummy indices are resolved. Factors
      equal to the identity are dropped, the non-tensor factors are
      multiplied into a single coefficient and the contracted indices of
      the remaining tensor factors are renamed.

      Returns
      =======

      The coefficient alone if no tensor factor is left, ``S.Zero`` if the
      coefficient is zero, the single remaining factor if there is only
      one, and otherwise a new object built with ``self.func``.
      """
      canon_flag = self._is_canon_bp
      if hints.get('deep', True):
          evaluated = [arg.doit(**hints) for arg in self.args]
          # There may now be conflicts between dummy indices of different
          # args (each arg's doit method does not have any information about
          # which dummy indices are already used in the other args), so we
          # deduplicate them.
          renamed = self._dedupe_indices_in_rule(dict(zip(self.args, evaluated)))
          factors = [renamed[arg] for arg in self.args]
      else:
          factors = self.args

      factors = [factor for factor in factors if factor != self.identity]

      # Separate the non-tensor coefficient from the tensor factors:
      scalars = [factor for factor in factors if not isinstance(factor, TensExpr)]
      coeff = reduce(lambda a, b: a*b, scalars, S.One)
      tensors = [factor for factor in factors if isinstance(factor, TensExpr)]

      if not tensors:
          return coeff

      if coeff != self.identity:
          tensors = [coeff] + tensors
      if coeff == 0:
          return S.Zero

      if len(tensors) == 1:
          return tensors[0]

      tensors, indices, free, dum = TensMul._tensMul_contract_indices(tensors)

      # Data for indices:
      index_types = [index.tensor_index_type for index in indices]
      structure = _IndexStructure(free, dum, index_types, indices, canon_bp=canon_flag)

      result = self.func(*tensors)
      result._index_types = index_types
      result._index_structure = structure
      result._ext_rank = len(structure.free) + 2*len(structure.dum)
      result._coeff = coeff
      result._is_canon_bp = canon_flag
      return result
  # TODO: this method should be private
  # TODO: should this method be renamed _from_components_free_dum ?
  @staticmethod
  def from_data(coeff, components, free, dum, **kw_args):
      """
      Build an evaluated ``TensMul`` from a coefficient, a list of
      components and the free/dummy index data.

      Extra keyword arguments are forwarded to ``TensMul``.
      """
      tensors = TensMul._get_tensors_from_components_free_dum(components, free, dum)
      product = TensMul(coeff, *tensors, **kw_args)
      return product.doit()
  @staticmethod
  def _get_tensors_from_components_free_dum(components, free, dum):
      """
      Build one ``Tensor`` per component by handing out the indices
      described by ``free`` and ``dum``.

      Each component receives as many consecutive indices as its rank.
      """
      structure = _IndexStructure.from_components_free_dum(components, free, dum)
      all_indices = structure.get_indices()
      tensors = [None for _ in components]  # one slot per component
      start = 0
      for slot, component in enumerate(components):
          stop = start + component.rank
          tensors[slot] = Tensor(component, all_indices[start:stop])
          start = stop
      return tensors
  def _get_free_indices_set(self):
      """Return the set of indices taken from the entries of ``self.free``."""
      return {entry[0] for entry in self.free}
  def _get_dummy_indices_set(self):
      """Return the set of indices located at the positions listed in ``self.dum``."""
      dummy_positions = {position for pair in self.dum for position in pair}
      all_indices = self._index_structure.get_indices()
      return {
          index
          for position, index in enumerate(all_indices)
          if position in dummy_positions
      }
  def _get_position_offset_for_indices(self):
      """
      For every overall index position return the position at which the
      indices of the owning tensor argument start.

      Positions not covered by any tensor argument keep the value ``None``.
      """
      offsets = [None for _ in range(self.ext_rank)]
      start = 0
      for arg in self.args:
          if isinstance(arg, TensExpr):
              for local_position in range(arg.ext_rank):
                  offsets[start + local_position] = start
              start += arg.ext_rank
      return offsets
  @property
  def free_args(self):
      """Sorted list of the free indices."""
      indices = [entry[0] for entry in self.free]
      indices.sort()
      return indices
  @property
  def components(self):
      """Components of the tensor factors among the arguments."""
      factors = self.args
      return self._get_components_from_args(factors)
  @property
  def free_in_args(self):
      """
      Free indices as ``(index, position inside its argument, argument number)``
      triples.
      """
      offsets = self._get_position_offset_for_indices()
      owner = self._get_indices_to_args_pos()
      located = []
      for index, position in self.free:
          located.append((index, position - offsets[position], owner[position]))
      return located
  @property
  def coeff(self):
      """Coefficient stored on the object as ``_coeff``."""
      # Recomputing the coefficient from the non-tensor args was left
      # disabled in favour of the cached value:
      # return Mul.fromiter([c for c in self.args if not isinstance(c, TensExpr)])
      stored = self._coeff
      return stored
  @property
  def nocoeff(self):
      """Evaluated product of the tensor factors only, without the other arguments."""
      tensor_factors = [factor for factor in self.args if isinstance(factor, TensExpr)]
      rebuilt = self.func(*tensor_factors)
      return rebuilt.doit()
  @property
  def dum_in_args(self):
      """
      Contracted pairs as 4-tuples: the position of each index inside its
      argument followed by the two argument numbers.
      """
      offsets = self._get_position_offset_for_indices()
      owner = self._get_indices_to_args_pos()
      located = []
      for first, second in self.dum:
          located.append(
              (first - offsets[first], second - offsets[second], owner[first], owner[second])
          )
      return located
  def equals(self, other):
      """
      Compare ``self`` with ``other``.

      A zero ``other`` is compared against the coefficient. Tensor
      expressions are compared after canonicalization with ``canon_bp``.
      Any other object is compared against the coefficient, which requires
      ``self`` to have no components.
      """
      if other == 0:
          return self.coeff == 0
      other = _sympify(other)
      if isinstance(other, TensExpr):
          return self.canon_bp() == other.canon_bp()
      assert not self.components
      return self.coeff == other
  def get_indices(self):
      """
      Return the list of all indices of the tensor product.

      Explanation
      ===========

      Indices appear in the same order as in the component tensors.
      Dummy indices carry names that do not clash with the names of the
      free indices.

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads
      >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
      >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz)
      >>> g = Lorentz.metric
      >>> p, q = tensor_heads('p,q', [Lorentz])
      >>> t = p(m1)*g(m0,m2)
      >>> t.get_indices()
      [m1, m0, m2]
      >>> t2 = p(m1)*g(-m1, m2)
      >>> t2.get_indices()
      [L_0, -L_0, m2]
      """
      stored_indices = self._indices
      return stored_indices
  def get_free_indices(self) -> list[TensorIndex]:
      """
      Return the list of free indices of the tensor product.

      Explanation
      ===========

      Indices appear in the same order as in the component tensors.

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads
      >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
      >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz)
      >>> g = Lorentz.metric
      >>> p, q = tensor_heads('p,q', [Lorentz])
      >>> t = p(m1)*g(m0,m2)
      >>> t.get_free_indices()
      [m1, m0, m2]
      >>> t2 = p(m1)*g(-m1, m2)
      >>> t2.get_free_indices()
      [m2]
      """
      structure = self._index_structure
      return structure.get_free_indices()
  def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr:
      """
      Rebuild the product with ``repl`` applied to every tensor argument.

      Non-tensor arguments are passed through unchanged.
      """
      rebuilt = []
      for arg in self.args:
          if isinstance(arg, TensExpr):
              rebuilt.append(arg._replace_indices(repl))
          else:
              rebuilt.append(arg)
      return self.func(*rebuilt)
  def split(self):
      """
      Return a list of tensors whose product is ``self``.

      Explanation
      ===========

      Dummy indices contracted between different tensor components turn
      into free indices carrying the name that was used for the dummy
      index.

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads, TensorSymmetry
      >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
      >>> a, b, c, d = tensor_indices('a,b,c,d', Lorentz)
      >>> A, B = tensor_heads('A,B', [Lorentz]*2, TensorSymmetry.fully_symmetric(2))
      >>> t = A(a,b)*B(-b,c)
      >>> t
      A(a, L_0)*B(-L_0, c)
      >>> t.split()
      [A(a, L_0), B(-L_0, c)]
      """
      if self.args == ():
          return [self]
      pieces = []
      # Factors seen since the last ``Tensor``; they are attached to the
      # next ``Tensor`` encountered.
      pending = 1
      for factor in self.args:
          if not isinstance(factor, Tensor):
              pending *= factor
              continue
          pieces.append(pending*factor)
          pending = 1
      return pieces
  def _expand(self, **hints):
      """
      Distribute the product over the sums among its (expanded) arguments.

      Every argument is expanded first; the result is a ``TensAdd`` of one
      ``TensMul`` per combination of summands.
      """
      # TODO: temporary solution, in the future this should be linked to
      # `Expr.expand`.
      expanded = [_expand(arg, **hints) for arg in self.args]
      choices = []
      for factor in expanded:
          if isinstance(factor, (Add, TensAdd)):
              choices.append(factor.args)
          else:
              choices.append((factor,))
      products = [TensMul(*combination) for combination in itertools.product(*choices)]
      return TensAdd(*products)
  def __neg__(self):
      """Return the evaluated product of ``S.NegativeOne`` and ``self``."""
      negated = TensMul(S.NegativeOne, self, is_canon_bp=self._is_canon_bp)
      return negated.doit()
  def __getitem__(self, item):
      """Deprecated: index into the component data of the expression."""
      deprecate_data()
      with ignore_warnings(SymPyDeprecationWarning):
          component_data = self.data
          return component_data[item]
  def _get_args_for_traditional_printer(self):
      """
      Return ``(sign, args)`` for printing.

      ``sign`` is ``"-"`` when a minus sign can be extracted from the
      coefficient and ``""`` otherwise. In the first case the leading
      argument is dropped if it equals ``S.NegativeOne`` and negated
      otherwise.
      """
      factors = list(self.args)
      if not self.coeff.could_extract_minus_sign():
          return "", factors
      # expressions like "-A(a)"
      if factors[0] == S.NegativeOne:
          factors = factors[1:]
      else:
          factors[0] = -factors[0]
      return "-", factors
  def _sort_args_for_sorted_components(self):
      """
      Returns the ``args`` sorted according to the components commutation
      properties.

      Explanation
      ===========

      The sorting is done taking into account the commutation group
      of the component tensors. Only adjacent tensor factors that commute
      or anticommute are exchanged; every exchange of anticommuting
      factors flips the overall sign.

      Returns
      =======

      The list of sorted tensor factors, preceded by the coefficient
      (including the accumulated sign) when that coefficient differs
      from 1.
      """
      # Only the tensor factors take part in the sorting.
      cv = [arg for arg in self.args if isinstance(arg, TensExpr)]
      sign = 1
      n = len(cv) - 1
      # Bubble sort: adjacent pairs are compared from the end of the list
      # towards position ``i``.
      for i in range(n):
          for j in range(n, i, -1):
              c = cv[j-1].commutes_with(cv[j])
              # if `c` is `None`, it does neither commute nor anticommute, skip:
              if c not in (0, 1):
                  continue
              # Sort key: the index types (ordered by name), then the
              # component name.
              typ1 = sorted(set(cv[j-1].component.index_types), key=lambda x: x.name)
              typ2 = sorted(set(cv[j].component.index_types), key=lambda x: x.name)
              if (typ1, cv[j-1].component.name) > (typ2, cv[j].component.name):
                  cv[j-1], cv[j] = cv[j], cv[j-1]
                  # if `c` is 1, they anticommute, so change sign:
                  if c:
                      sign = -sign

      coeff = sign * self.coeff
      if coeff != 1:
          return [coeff] + cv
      return cv
  def sorted_components(self):
      """
      Return the evaluated tensor product built from the arguments sorted
      by ``_sort_args_for_sorted_components``.
      """
      ordered_args = self._sort_args_for_sorted_components()
      return TensMul(*ordered_args).doit()
  def perm2tensor(self, g, is_canon_bp=False):
      """
      Return the tensor corresponding to the permutation ``g``.

      This delegates to the module-level function ``perm2tensor``; see the
      method with the same name in ``TIDS`` for further details.
      """
      permuted = perm2tensor(self, g, is_canon_bp=is_canon_bp)
      return permuted
  def canon_bp(self):
      """
      Canonicalize using the Butler-Portugal algorithm for canonicalization
      under monoterm symmetries.

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, TensorSymmetry
      >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
      >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz)
      >>> A = TensorHead('A', [Lorentz]*2, TensorSymmetry.fully_symmetric(-2))
      >>> t = A(m0,-m1)*A(m1,-m0)
      >>> t.canon_bp()
      -A(L_0, L_1)*A(-L_0, -L_1)
      >>> t = A(m0,-m1)*A(m1,-m2)*A(m2,-m0)
      >>> t.canon_bp()
      0
      """
      # Already canonical: nothing to do.
      if self._is_canon_bp:
          return self

      expanded = self.expand()
      if isinstance(expanded, TensAdd):
          return expanded.canon_bp()
      if not expanded.components:
          return expanded

      ordered = expanded.sorted_components()
      perm, dummies, msym = ordered._index_structure.indices_canon_args()
      component_data = components_canon_args(ordered.components)
      canonical = canonicalize(perm, dummies, msym, *component_data)
      if canonical == 0:
          return S.Zero
      return ordered.perm2tensor(canonical, True)
  def contract_delta(self, delta):
      """Contract with ``delta`` by delegating to ``contract_metric``."""
      return self.contract_metric(delta)
  def _get_indices_to_args_pos(self):
      """
      Return a dict that maps every overall index position to the number of
      the ``TensMul`` argument owning that index.
      """
      owner = {}
      next_position = 0
      for arg_no, arg in enumerate(self.args):
          if isinstance(arg, TensExpr):
              assert isinstance(arg, Tensor)
              stop = next_position + arg.ext_rank
              for position in range(next_position, stop):
                  owner[position] = arg_no
              next_position = stop
      return owner
  def contract_metric(self, g):
      """
      Raise or lower indices with the metric ``g``.

      Parameters
      ==========

      g : metric

      Returns
      =======

      ``self`` when ``g`` does not occur among the arguments; otherwise the
      product with the contracted occurrences of ``g`` removed and the
      indices of the remaining factors rewired accordingly. The result may
      be a non-tensor expression when no tensor factor remains.

      Raises
      ======

      NotImplementedError
          If the symmetry of ``g`` is none of fully symmetric, fully
          antisymmetric or no symmetry of rank 2.

      Notes
      =====

      See the ``TensorIndexType`` docstring for the contraction conventions.

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, tensor_heads
      >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
      >>> m0, m1, m2 = tensor_indices('m0,m1,m2', Lorentz)
      >>> g = Lorentz.metric
      >>> p, q = tensor_heads('p,q', [Lorentz])
      >>> t = p(m0)*q(m1)*g(-m0, -m1)
      >>> t.canon_bp()
      metric(L_0, L_1)*p(-L_0)*q(-L_1)
      >>> t.contract_metric(g).canon_bp()
      p(L_0)*q(-L_0)
      """
      # If expanding changes the expression, canonicalize the expanded form
      # and hand it to the module-level ``contract_metric``.
      expr = self.expand()
      if self != expr:
          expr = canon_bp(expr)
          return contract_metric(expr, g)
      # Map: overall index position -> number of the owning argument.
      pos_map = self._get_indices_to_args_pos()
      args = list(self.args)

      #antisym = g.index_types[0].metric_antisym
      # ``antisym`` is 1 for an antisymmetric metric, 0 for a symmetric one
      # and ``None`` for a metric without symmetry.
      if g.symmetry == TensorSymmetry.fully_symmetric(-2):
          antisym = 1
      elif g.symmetry == TensorSymmetry.fully_symmetric(2):
          antisym = 0
      elif g.symmetry == TensorSymmetry.no_symmetry(2):
          antisym = None
      else:
          raise NotImplementedError

      # list of positions of the metric ``g`` inside ``args``
      gpos = [i for i, x in enumerate(self.args) if isinstance(x, Tensor) and x.component == g]
      if not gpos:
          return self

      # Sign is either 1 or -1, to correct the sign after metric contraction
      # (for spinor indices).
      # NOTE(review): ``sign`` is also multiplied by the dimension of the
      # index type below when a metric is contracted with itself.
      sign = 1
      dum = self.dum[:]
      free = self.free[:]
      # Argument positions of the metrics that get removed.
      elim = set()
      for gposx in gpos:
          if gposx in elim:
              continue
          # Free and dummy entries that touch this occurrence of the metric.
          free1 = [x for x in free if pos_map[x[1]] == gposx]
          dum1 = [x for x in dum if pos_map[x[0]] == gposx or pos_map[x[1]] == gposx]
          if not dum1:
              # The metric is not contracted with anything: keep it.
              continue
          elim.add(gposx)
          # subs with the multiplication neutral element, that is, remove it:
          args[gposx] = 1
          if len(dum1) == 2:
              # Both indices of the metric are contracted: the two partner
              # positions ``p0`` and ``p1`` become a new dummy pair.
              if not antisym:
                  dum10, dum11 = dum1
                  if pos_map[dum10[1]] == gposx:
                      # the index with pos p0 is contravariant
                      p0 = dum10[0]
                  else:
                      # the index with pos p0 is covariant
                      p0 = dum10[1]
                  if pos_map[dum11[1]] == gposx:
                      # the index with pos p1 is contravariant
                      p1 = dum11[0]
                  else:
                      # the index with pos p1 is covariant
                      p1 = dum11[1]

                  dum.append((p0, p1))
              else:
                  dum10, dum11 = dum1
                  # change the sign to bring the indices of the metric to contravariant
                  # form; change the sign if dum10 has the metric index in position 0
                  if pos_map[dum10[1]] == gposx:
                      # the index with pos p0 is contravariant
                      p0 = dum10[0]
                      if dum10[1] == 1:
                          sign = -sign
                  else:
                      # the index with pos p0 is covariant
                      p0 = dum10[1]
                      if dum10[0] == 0:
                          sign = -sign
                  if pos_map[dum11[1]] == gposx:
                      # the index with pos p1 is contravariant
                      p1 = dum11[0]
                      sign = -sign
                  else:
                      # the index with pos p1 is covariant
                      p1 = dum11[1]

                  dum.append((p0, p1))

          elif len(dum1) == 1:
              # Exactly one dummy pair touches the metric: either the metric
              # is contracted with itself, or one of its indices is free.
              if not antisym:
                  dp0, dp1 = dum1[0]
                  if pos_map[dp0] == pos_map[dp1]:
                      # g(i, -i)
                      typ = g.index_types[0]
                      sign = sign*typ.dim

                  else:
                      # g(i0, i1)*p(-i1)
                      if pos_map[dp0] == gposx:
                          p1 = dp1
                      else:
                          p1 = dp0

                      # The free index of the metric moves to position ``p1``.
                      ind, p = free1[0]
                      free.append((ind, p1))
              else:
                  dp0, dp1 = dum1[0]
                  if pos_map[dp0] == pos_map[dp1]:
                      # g(i, -i)
                      typ = g.index_types[0]
                      sign = sign*typ.dim

                      if dp0 < dp1:
                          # g(i, -i) = -D with antisymmetric metric
                          sign = -sign
                  else:
                      # g(i0, i1)*p(-i1)
                      if pos_map[dp0] == gposx:
                          p1 = dp1
                          if dp0 == 0:
                              sign = -sign
                      else:
                          p1 = dp0

                      # The free index of the metric moves to position ``p1``.
                      ind, p = free1[0]
                      free.append((ind, p1))
          # Drop the entries that referred to the removed metric.
          dum = [x for x in dum if x not in dum1]
          free = [x for x in free if x not in free1]

      # shift positions:
      # every eliminated metric accounts for two index positions, so the
      # positions of the factors after it move down by two.
      shift = 0
      shifts = [0]*len(args)
      for i in range(len(args)):
          if i in elim:
              shift += 2
              continue
          shifts[i] = shift
      free = [(ind, p - shifts[pos_map[p]]) for (ind, p) in free if pos_map[p] not in elim]
      dum = [(p0 - shifts[pos_map[p0]], p1 - shifts[pos_map[p1]]) for i, (p0, p1) in enumerate(dum) if pos_map[p0] not in elim and pos_map[p1] not in elim]

      res = sign*TensMul(*args).doit()
      if not isinstance(res, TensExpr):
          return res
      # Rebuild the index structure from the updated free/dummy data and
      # apply it to the result.
      im = _IndexStructure.from_components_free_dum(res.components, free, dum)
      return res._set_new_index_structure(im)
  def _set_new_index_structure(self, im, is_canon_bp=False):
      """Apply the indices of the index structure ``im`` through ``_set_indices``."""
      new_indices = im.get_indices()
      return self._set_indices(*new_indices, is_canon_bp=is_canon_bp)
  def _set_indices(self, *indices, is_canon_bp=False, **kw_args):
      """
      Return the evaluated product with ``indices`` distributed over its
      tensor arguments.

      Each tensor argument receives as many consecutive indices as its
      ``ext_rank``. Extra keyword arguments are accepted and ignored.

      Raises
      ======

      ValueError
          If the number of indices differs from ``self.ext_rank``.
      """
      if len(indices) != self.ext_rank:
          raise ValueError("indices length mismatch")
      rebuilt = []
      start = 0
      for arg in self.args:
          if not isinstance(arg, TensExpr):
              rebuilt.append(arg)
              continue
          assert isinstance(arg, Tensor)
          width = arg.ext_rank
          rebuilt.append(arg._set_indices(*indices[start:start+width]))
          start += width
      return TensMul(*rebuilt, is_canon_bp=is_canon_bp).doit()
  @staticmethod
  def _index_replacement_for_contract_metric(args, free, dum):
      """
      Check that every tensor expression in ``args`` is a ``Tensor``.

      ``free`` and ``dum`` are not used and nothing is returned.
      """
      for arg in args:
          if isinstance(arg, TensExpr):
              assert isinstance(arg, Tensor)
  def substitute_indices(self, *index_tuples):
      """
      Return the evaluated product in which ``substitute_indices`` has been
      applied to every tensor argument.

      Non-tensor arguments are kept as they are.
      """
      replaced = [
          arg.substitute_indices(*index_tuples) if isinstance(arg, TensExpr) else arg
          for arg in self.args
      ]
      return TensMul(*replaced).doit()
  def __call__(self, *indices):
      """
      Deprecated: replace the free indices of the product by ``indices``.

      Raises
      ======

      ValueError
          If the index types of ``indices`` do not match those of the
          (sorted) free indices.
      """
      deprecate_call()
      current = self.free_args
      indices = list(indices)
      requested_types = [index.tensor_index_type for index in indices]
      current_types = [index.tensor_index_type for index in current]
      if requested_types != current_types:
          raise ValueError('incompatible types')
      if indices == current:
          return self
      substituted = self.substitute_indices(*zip(current, indices))

      # object is rebuilt in order to make sure that all contracted indices
      # get recognized as dummies, but only if there are contracted indices.
      raised = {index if index.is_up else -index for index in indices}
      if len(raised) == len(indices):
          return substituted
      return substituted.func(*substituted.args)
  def _extract_data(self, replacement_dict):
      """
      Return ``(free_indices, data)`` for the product.

      The data of the tensor arguments are contracted over the dummy
      indices and multiplied by the product of the non-tensor arguments.
      The free indices are listed in order of position.
      """
      extracted = [
          arg._extract_data(replacement_dict)
          for arg in self.args
          if isinstance(arg, TensExpr)
      ]
      args_indices, arrays = zip(*extracted)
      scalars = [arg for arg in self.args if not isinstance(arg, TensExpr)]
      coeff = reduce(operator.mul, scalars, S.One)
      _indices, free, _free_names, dummy_data = TensMul._indices_to_free_dum(args_indices)
      dum = TensMul._dummy_data_to_dum(dummy_data)
      ext_rank = self.ext_rank
      ordered_free = sorted(free, key=lambda entry: entry[1])
      free_indices = [entry[0] for entry in ordered_free]
      contracted = _TensorDataLazyEvaluator.data_contract_dum(arrays, dum, ext_rank)
      return free_indices, coeff*contracted
  @property
  def data(self):
      """Deprecated: component data looked up for the expanded expression."""
      deprecate_data()
      with ignore_warnings(SymPyDeprecationWarning):
          return _tensor_data_substitution_dict[self.expand()]

  @data.setter
  def data(self, data):
      # Component data cannot be assigned to an expression: always raises.
      deprecate_data()
      raise ValueError("Not possible to set component data to a tensor expression")

  @data.deleter
  def data(self):
      # Component data cannot be deleted from an expression: always raises.
      deprecate_data()
      raise ValueError("Not possible to delete component data to a tensor expression")
  def __iter__(self):
      """
      Deprecated: iterate over the component data.

      Raises
      ======

      ValueError
          If the expression has no component data.
      """
      deprecate_data()
      with ignore_warnings(SymPyDeprecationWarning):
          if self.data is not None:
              return self.data.__iter__()
          raise ValueError("No iteration on abstract tensors")
  @staticmethod
  def _dedupe_indices(new, exclude):
      """
      Rename the dummy indices of ``new`` that clash with ``exclude``.

      Parameters
      ==========

      new : TensExpr
      exclude : set

      Returns
      =======

      A version of ``new`` in which every dummy index that also occurs in
      ``exclude`` has been replaced, or ``None`` if no replacement is
      needed.
      """
      exclude = set(exclude)
      new_dummies = set(get_dummy_indices(new))
      new_free = set(get_free_indices(new))

      clashing = new_dummies.intersection(exclude)
      if not clashing:
          return None

      # The name generator of ``_IndexStructure`` expects (index, position)
      # pairs. It does not use the position for anything, so ``None`` is
      # passed in its place.
      exclude.update(new_dummies)
      exclude.update(new_free)
      taken = [(index, None) for index in exclude]
      fresh_name = _IndexStructure._get_generator_for_dummy_indices(taken)

      renaming = {}
      for old in clashing:
          # Both valences of a pair are handled together.
          if -old in renaming:
              continue
          replacement = old.func(fresh_name(old.tensor_index_type), *old.args[1:])
          renaming[old] = replacement
          renaming[-old] = -replacement

      if not renaming:
          return None

      return new._replace_indices(renaming)
  def _dedupe_indices_in_rule(self, rule):
      """
      Apply ``TensMul._dedupe_indices`` to all values of ``rule``.

      Parameters
      ==========

      rule : dict

      Returns
      =======

      A new dict. Entries keyed by a ``TensorIndex`` are copied unchanged;
      the other values are replaced by their renamed version whenever a
      renaming was necessary.
      """
      index_rules = {}
      for key, value in rule.items():
          if isinstance(key, TensorIndex):
              index_rules[key] = value
      other_rules = {key: value for key, value in rule.items() if key not in index_rules}

      # Indices that replacement values must not reuse as dummies.
      taken = set(self.get_indices())
      taken.update(index_rules.keys())
      taken.update(index_rules.values())

      deduped = dict(index_rules)
      for old, new in other_rules.items():
          renamed = TensMul._dedupe_indices(new, taken)
          if old != new and renamed is not None:
              deduped[old] = renamed
              taken.update(get_indices(renamed))
          else:
              deduped[old] = new
      return deduped
  def _eval_rewrite_as_Indexed(self, *args):
      """
      Rewrite the product in terms of the rewritten arguments ``args``.

      A ``Sum`` argument contributes its first argument only; the final
      expression is passed to ``_check_add_Sum`` together with the index
      symbols of ``self``.
      """
      from sympy.concrete.summations import Sum
      index_symbols = [index.args[0] for index in self.get_indices()]
      factors = []
      for arg in args:
          factors.append(arg.args[0] if isinstance(arg, Sum) else arg)
      product = Mul.fromiter(factors)
      return self._check_add_Sum(product, index_symbols)
  def _eval_partial_derivative(self, s):
      """
      Differentiate the product with respect to ``s`` as for ``Mul``: one
      term per argument whose derivative is nonzero, with that argument
      replaced by its derivative.
      """
      summands = []
      for position, factor in enumerate(self.args):
          # Tensor expressions have their own derivative method; other
          # arguments are differentiated only when ``s._diff_wrt`` is true.
          if isinstance(factor, TensExpr):
              derivative = factor._eval_partial_derivative(s)
          elif s._diff_wrt:
              derivative = factor._eval_derivative(s)
          else:
              derivative = S.Zero
          if derivative:
              before = self.args[:position]
              after = self.args[position + 1:]
              summands.append(TensMul.fromiter(before + (derivative,) + after))
      return TensAdd.fromiter(summands)
  class TensorElement(TensExpr):
      """
      Tensor with evaluated components.

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorIndexType, TensorHead, TensorSymmetry
      >>> from sympy import symbols
      >>> L = TensorIndexType("L")
      >>> i, j, k = symbols("i j k")
      >>> A = TensorHead("A", [L, L], TensorSymmetry.fully_symmetric(2))
      >>> A(i, j).get_free_indices()
      [i, j]

      If we want to set component ``i`` to a specific value, use the
      ``TensorElement`` class:

      >>> from sympy.tensor.tensor import TensorElement
      >>> te = TensorElement(A(i, j), {i: 2})

      As index ``i`` has been accessed (``{i: 2}`` is the evaluation of its 3rd
      element), the free indices will only contain ``j``:

      >>> te.get_free_indices()
      [j]
      """

      def __new__(cls, expr, index_map):
          if not isinstance(expr, Tensor):
              # Not a single tensor: apply the evaluation to every argument.
              if isinstance(expr, TensExpr):
                  return expr.func(*[TensorElement(arg, index_map) for arg in expr.args])
              raise TypeError("%s is not a tensor expression" % expr)

          expr_free_indices = expr.get_free_indices()
          # Keys may be given as the first argument of an index instead of
          # the index itself; translate them, then keep only the entries
          # that refer to a free index of ``expr``.
          by_name = {index.args[0]: index for index in expr_free_indices}
          selected = {}
          for key, value in index_map.items():
              index = by_name.get(key, key)
              if index in expr_free_indices:
                  selected[index] = value
          if not selected:
              return expr

          remaining = [index for index in expr_free_indices if index not in selected]
          obj = TensExpr.__new__(cls, expr, Dict(selected))
          obj._free_indices = remaining
          return obj

      @property
      def free(self):
          # Pairs of (index, position among the free indices).
          free_indices = self.get_free_indices()
          return [(index, position) for position, index in enumerate(free_indices)]

      @property
      def dum(self):
          # TODO: inherit dummies from expr
          return []

      @property
      def expr(self):
          # The tensor whose components are evaluated.
          wrapped, _ = self._args[0], self._args[1]
          return wrapped

      @property
      def index_map(self):
          # The mapping from evaluated indices to their values.
          return self._args[1]

      @property
      def coeff(self):
          return S.One

      @property
      def nocoeff(self):
          return self

      def get_free_indices(self):
          """Free indices of the expression that have not been evaluated."""
          return self._free_indices

      def _replace_indices(self, repl: dict[TensorIndex, TensorIndex]) -> TensExpr:
          # TODO: can be improved:
          replaced = self.xreplace(repl)
          return replaced

      def get_indices(self):
          """Same as ``get_free_indices``."""
          return self.get_free_indices()

      def _extract_data(self, replacement_dict):
          """
          Return ``(indices, array)``: the data of the wrapped expression
          indexed at the evaluated positions and sliced along the others.
          """
          inner_indices, array = self.expr._extract_data(replacement_dict)
          index_map = self.index_map
          selector = tuple(index_map.get(index, slice(None)) for index in inner_indices)
          kept_indices = [index for index in inner_indices if index not in index_map]
          return kept_indices, array[selector]
  class WildTensorHead(TensorHead):
      """
      A wild object that is used to create ``WildTensor`` instances

      Examples
      ========

      >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType
      >>> R3 = TensorIndexType('R3', dim=3)
      >>> p = TensorIndex('p', R3)
      >>> q = TensorIndex('q', R3)

      A WildTensorHead can be created without specifying a ``TensorIndexType``

      >>> W = WildTensorHead("W")

      Calling it with a ``TensorIndex`` creates a ``WildTensor`` instance.

      >>> type(W(p))
      <class 'sympy.tensor.tensor.WildTensor'>

      The ``TensorIndexType`` is automatically detected from the index that is passed

      >>> W(p).component
      W(R3)

      Calling it with no indices returns an object that can match tensors with any number of indices.

      >>> K = TensorHead('K', [R3])
      >>> Q = TensorHead('Q', [R3, R3])
      >>> W().matches(K(p))
      {W: K(p)}
      >>> W().matches(Q(p,q))
      {W: Q(p, q)}

      If you want to ignore the order of indices while matching, pass ``unordered_indices=True``.

      >>> U = WildTensorHead("U", unordered_indices=True)
      >>> W(p,q).matches(Q(q,p))
      >>> U(p,q).matches(Q(q,p))
      {U(R3,R3): _WildTensExpr(Q(q, p))}

      Parameters
      ==========

      name : name of the tensor
      unordered_indices : whether the order of the indices matters for matching
          (default: False)

      See also
      ========

      ``WildTensor``
      ``TensorHead``
      """

      def __new__(cls, name, index_types=None, symmetry=None, comm=0, unordered_indices=False):
          if isinstance(name, Symbol):
              name_symbol = name
          elif isinstance(name, str):
              name_symbol = Symbol(name)
          else:
              raise ValueError("invalid name")

          if index_types is None:
              index_types = []

          rank = len(index_types)
          if symmetry is None:
              symmetry = TensorSymmetry.no_symmetry(rank)
          else:
              assert symmetry.rank == rank

          # Only the trivial symmetry is supported for wild tensors.
          if symmetry != TensorSymmetry.no_symmetry(rank):
              raise NotImplementedError("Wild matching based on symmetry is not implemented.")

          sympy_args = (
              name_symbol,
              Tuple(*index_types),
              sympify(symmetry),
              sympify(comm),
              sympify(unordered_indices),
          )
          obj = Basic.__new__(cls, *sympy_args)
          obj.comm = TensorManager.comm_symbols2i(comm)
          obj.unordered_indices = unordered_indices
          return obj

      def __call__(self, *indices, **kwargs):
          # Attach the indices and evaluate the resulting wild tensor.
          return WildTensor(self, indices, **kwargs).doit()
  3500. class WildTensor(Tensor):
  3501. """
  3502. A wild object which matches ``Tensor`` instances
  3503. Explanation
  3504. ===========
  3505. This is instantiated by attaching indices to a ``WildTensorHead`` instance.
  3506. Examples
  3507. ========
  3508. >>> from sympy.tensor.tensor import TensorHead, TensorIndex, WildTensorHead, TensorIndexType
  3509. >>> W = WildTensorHead("W")
  3510. >>> R3 = TensorIndexType('R3', dim=3)
  3511. >>> p = TensorIndex('p', R3)
  3512. >>> q = TensorIndex('q', R3)
  3513. >>> K = TensorHead('K', [R3])
  3514. >>> Q = TensorHead('Q', [R3, R3])
  3515. Matching also takes the indices into account
  3516. >>> W(p).matches(K(p))
  3517. {W(R3): _WildTensExpr(K(p))}
  3518. >>> W(p).matches(K(q))
  3519. >>> W(p).matches(K(-p))
  3520. If you want to match objects with any number of indices, just use a ``WildTensor`` with no indices.
  3521. >>> W().matches(K(p))
  3522. {W: K(p)}
  3523. >>> W().matches(Q(p,q))
  3524. {W: Q(p, q)}
  3525. See Also
  3526. ========
  3527. ``WildTensorHead``
  3528. ``Tensor``
  3529. """
  3530. def __new__(cls, tensor_head, indices, **kw_args):
  3531. is_canon_bp = kw_args.pop("is_canon_bp", False)
  3532. if tensor_head.func == TensorHead:
  3533. """
  3534. If someone tried to call WildTensor by supplying a TensorHead (not a WildTensorHead), return a normal tensor instead. This is helpful when using subs on an expression to replace occurrences of a WildTensorHead with a TensorHead.
  3535. """
  3536. return Tensor(tensor_head, indices, is_canon_bp=is_canon_bp, **kw_args)
  3537. elif tensor_head.func == _WildTensExpr:
  3538. return tensor_head(*indices)
  3539. indices = cls._parse_indices(tensor_head, indices)
  3540. index_types = [ind.tensor_index_type for ind in indices]
  3541. tensor_head = tensor_head.func(
  3542. tensor_head.name,
  3543. index_types,
  3544. symmetry=None,
  3545. comm=tensor_head.comm,
  3546. unordered_indices=tensor_head.unordered_indices,
  3547. )
  3548. obj = Basic.__new__(cls, tensor_head, Tuple(*indices))
  3549. obj.name = tensor_head.name
  3550. obj._index_structure = _IndexStructure.from_indices(*indices)
  3551. obj._free = obj._index_structure.free[:]
  3552. obj._dum = obj._index_structure.dum[:]
  3553. obj._ext_rank = obj._index_structure._ext_rank
  3554. obj._coeff = S.One
  3555. obj._nocoeff = obj
  3556. obj._component = tensor_head
  3557. obj._components = [tensor_head]
  3558. if tensor_head.rank != len(indices):
  3559. raise ValueError("wrong number of indices")
  3560. obj.is_canon_bp = is_canon_bp
  3561. obj._index_map = obj._build_index_map(indices, obj._index_structure)
  3562. return obj
  3563. def matches(self, expr, repl_dict=None, old=False):
  3564. if not isinstance(expr, TensExpr) and expr != S(1):
  3565. return None
  3566. if repl_dict is None:
  3567. repl_dict = {}
  3568. else:
  3569. repl_dict = repl_dict.copy()
  3570. if len(self.indices) > 0:
  3571. if not hasattr(expr, "get_free_indices"):
  3572. return None
  3573. expr_indices = expr.get_free_indices()
  3574. if len(expr_indices) != len(self.indices):
  3575. return None
  3576. if self._component.unordered_indices:
  3577. m = self._match_indices_ignoring_order(expr)
  3578. if m is None:
  3579. return None
  3580. else:
  3581. repl_dict.update(m)
  3582. else:
  3583. for i in range(len(expr_indices)):
  3584. m = self.indices[i].matches(expr_indices[i])
  3585. if m is None:
  3586. return None
  3587. else:
  3588. repl_dict.update(m)
  3589. repl_dict[self.component] = _WildTensExpr(expr)
  3590. else:
  3591. #If no indices were passed to the WildTensor, it may match tensors with any number of indices.
  3592. repl_dict[self] = expr
  3593. return repl_dict
  3594. def _match_indices_ignoring_order(self, expr, repl_dict=None, old=False):
  3595. """
  3596. Helper method for matches. Checks if the indices of self and expr
  3597. match disregarding index ordering.
  3598. """
  3599. if repl_dict is None:
  3600. repl_dict = {}
  3601. else:
  3602. repl_dict = repl_dict.copy()
  3603. def siftkey(ind):
  3604. if isinstance(ind, WildTensorIndex):
  3605. if ind.ignore_updown:
  3606. return "wild, updown"
  3607. else:
  3608. return "wild"
  3609. else:
  3610. return "nonwild"
  3611. indices_sifted = sift(self.indices, siftkey)
  3612. matched_indices = []
  3613. expr_indices_remaining = expr.get_indices()
  3614. for ind in indices_sifted["nonwild"]:
  3615. matched_this_ind = False
  3616. for e_ind in expr_indices_remaining:
  3617. if e_ind in matched_indices:
  3618. continue
  3619. m = ind.matches(e_ind)
  3620. if m is not None:
  3621. matched_this_ind = True
  3622. repl_dict.update(m)
  3623. matched_indices.append(e_ind)
  3624. break
  3625. if not matched_this_ind:
  3626. return None
  3627. expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices]
  3628. for ind in indices_sifted["wild"]:
  3629. matched_this_ind = False
  3630. for e_ind in expr_indices_remaining:
  3631. m = ind.matches(e_ind)
  3632. if m is not None:
  3633. if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]:
  3634. return None
  3635. matched_this_ind = True
  3636. repl_dict.update(m)
  3637. matched_indices.append(e_ind)
  3638. break
  3639. if not matched_this_ind:
  3640. return None
  3641. expr_indices_remaining = [i for i in expr_indices_remaining if i not in matched_indices]
  3642. for ind in indices_sifted["wild, updown"]:
  3643. matched_this_ind = False
  3644. for e_ind in expr_indices_remaining:
  3645. m = ind.matches(e_ind)
  3646. if m is not None:
  3647. if -ind in repl_dict.keys() and -repl_dict[-ind] != m[ind]:
  3648. return None
  3649. matched_this_ind = True
  3650. repl_dict.update(m)
  3651. matched_indices.append(e_ind)
  3652. break
  3653. if not matched_this_ind:
  3654. return None
  3655. if len(matched_indices) < len(self.indices):
  3656. return None
  3657. else:
  3658. return repl_dict
  3659. class WildTensorIndex(TensorIndex):
  3660. """
  3661. A wild object that matches TensorIndex instances.
  3662. Examples
  3663. ========
  3664. >>> from sympy.tensor.tensor import TensorIndex, TensorIndexType, WildTensorIndex
  3665. >>> R3 = TensorIndexType('R3', dim=3)
  3666. >>> p = TensorIndex("p", R3)
  3667. By default, covariant indices only match with covariant indices (and
  3668. similarly for contravariant)
  3669. >>> q = WildTensorIndex("q", R3)
  3670. >>> (q).matches(p)
  3671. {q: p}
  3672. >>> (q).matches(-p)
  3673. If you want matching to ignore whether the index is co/contra-variant, set
  3674. ignore_updown=True
  3675. >>> r = WildTensorIndex("r", R3, ignore_updown=True)
  3676. >>> (r).matches(-p)
  3677. {r: -p}
  3678. >>> (r).matches(p)
  3679. {r: p}
  3680. Parameters
  3681. ==========
  3682. name : name of the index (string), or ``True`` if you want it to be
  3683. automatically assigned
  3684. tensor_index_type : ``TensorIndexType`` of the index
  3685. is_up : flag for contravariant index (is_up=True by default)
  3686. ignore_updown : bool, Whether this should match both co- and contra-variant
  3687. indices (default:False)
  3688. """
  3689. def __new__(cls, name, tensor_index_type, is_up=True, ignore_updown=False):
  3690. if isinstance(name, str):
  3691. name_symbol = Symbol(name)
  3692. elif isinstance(name, Symbol):
  3693. name_symbol = name
  3694. elif name is True:
  3695. name = "_i{}".format(len(tensor_index_type._autogenerated))
  3696. name_symbol = Symbol(name)
  3697. tensor_index_type._autogenerated.append(name_symbol)
  3698. else:
  3699. raise ValueError("invalid name")
  3700. is_up = sympify(is_up)
  3701. ignore_updown = sympify(ignore_updown)
  3702. return Basic.__new__(cls, name_symbol, tensor_index_type, is_up, ignore_updown)
  3703. @property
  3704. def ignore_updown(self):
  3705. return self.args[3]
  3706. def __neg__(self):
  3707. t1 = WildTensorIndex(self.name, self.tensor_index_type,
  3708. (not self.is_up), self.ignore_updown)
  3709. return t1
  3710. def matches(self, expr, repl_dict=None, old=False):
  3711. if not isinstance(expr, TensorIndex):
  3712. return None
  3713. if self.tensor_index_type != expr.tensor_index_type:
  3714. return None
  3715. if not self.ignore_updown:
  3716. if self.is_up != expr.is_up:
  3717. return None
  3718. if repl_dict is None:
  3719. repl_dict = {}
  3720. else:
  3721. repl_dict = repl_dict.copy()
  3722. repl_dict[self] = expr
  3723. return repl_dict
  3724. class _WildTensExpr(Basic):
  3725. """
  3726. INTERNAL USE ONLY
  3727. This is an object that helps with replacement of WildTensors in expressions.
  3728. When this object is set as the tensor_head of a WildTensor, it replaces the
  3729. WildTensor by a TensExpr (passed when initializing this object).
  3730. Examples
  3731. ========
  3732. >>> from sympy.tensor.tensor import WildTensorHead, TensorIndex, TensorHead, TensorIndexType
  3733. >>> W = WildTensorHead("W")
  3734. >>> R3 = TensorIndexType('R3', dim=3)
  3735. >>> p = TensorIndex('p', R3)
  3736. >>> q = TensorIndex('q', R3)
  3737. >>> K = TensorHead('K', [R3])
  3738. >>> print( ( K(p) ).replace( W(p), W(q)*W(-q)*W(p) ) )
  3739. K(R_0)*K(-R_0)*K(p)
  3740. """
  3741. def __init__(self, expr):
  3742. if not isinstance(expr, TensExpr):
  3743. raise TypeError("_WildTensExpr expects a TensExpr as argument")
  3744. self.expr = expr
  3745. def __call__(self, *indices):
  3746. return self.expr._replace_indices(dict(zip(self.expr.get_free_indices(), indices)))
  3747. def __neg__(self):
  3748. return self.func(self.expr*S.NegativeOne)
  3749. def __abs__(self):
  3750. raise NotImplementedError
  3751. def __add__(self, other):
  3752. if other.func != self.func:
  3753. raise TypeError(f"Cannot add {self.func} to {other.func}")
  3754. return self.func(self.expr+other.expr)
  3755. def __radd__(self, other):
  3756. if other.func != self.func:
  3757. raise TypeError(f"Cannot add {self.func} to {other.func}")
  3758. return self.func(other.expr+self.expr)
  3759. def __sub__(self, other):
  3760. return self + (-other)
  3761. def __rsub__(self, other):
  3762. return other + (-self)
  3763. def __mul__(self, other):
  3764. raise NotImplementedError
  3765. def __rmul__(self, other):
  3766. raise NotImplementedError
  3767. def __truediv__(self, other):
  3768. raise NotImplementedError
  3769. def __rtruediv__(self, other):
  3770. raise NotImplementedError
  3771. def __pow__(self, other):
  3772. raise NotImplementedError
  3773. def __rpow__(self, other):
  3774. raise NotImplementedError
  3775. def canon_bp(p):
  3776. """
  3777. Butler-Portugal canonicalization. See ``tensor_can.py`` from the
  3778. combinatorics module for the details.
  3779. """
  3780. if isinstance(p, TensExpr):
  3781. return p.canon_bp()
  3782. return p
  3783. def tensor_mul(*a):
  3784. """
  3785. product of tensors
  3786. """
  3787. if not a:
  3788. return TensMul.from_data(S.One, [], [], [])
  3789. t = a[0]
  3790. for tx in a[1:]:
  3791. t = t*tx
  3792. return t
  3793. def riemann_cyclic_replace(t_r):
  3794. """
  3795. replace Riemann tensor with an equivalent expression
  3796. ``R(m,n,p,q) -> 2/3*R(m,n,p,q) - 1/3*R(m,q,n,p) + 1/3*R(m,p,n,q)``
  3797. """
  3798. free = sorted(t_r.free, key=lambda x: x[1])
  3799. m, n, p, q = [x[0] for x in free]
  3800. t0 = t_r*Rational(2, 3)
  3801. t1 = -t_r.substitute_indices((m,m),(n,q),(p,n),(q,p))*Rational(1, 3)
  3802. t2 = t_r.substitute_indices((m,m),(n,p),(p,n),(q,q))*Rational(1, 3)
  3803. t3 = t0 + t1 + t2
  3804. return t3
  3805. def riemann_cyclic(t2):
  3806. """
  3807. Replace each Riemann tensor with an equivalent expression
  3808. satisfying the cyclic identity.
  3809. This trick is discussed in the reference guide to Cadabra.
  3810. Examples
  3811. ========
  3812. >>> from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, riemann_cyclic, TensorSymmetry
  3813. >>> Lorentz = TensorIndexType('Lorentz', dummy_name='L')
  3814. >>> i, j, k, l = tensor_indices('i,j,k,l', Lorentz)
  3815. >>> R = TensorHead('R', [Lorentz]*4, TensorSymmetry.riemann())
  3816. >>> t = R(i,j,k,l)*(R(-i,-j,-k,-l) - 2*R(-i,-k,-j,-l))
  3817. >>> riemann_cyclic(t)
  3818. 0
  3819. """
  3820. t2 = t2.expand()
  3821. if isinstance(t2, (TensMul, Tensor)):
  3822. args = [t2]
  3823. else:
  3824. args = t2.args
  3825. a1 = [x.split() for x in args]
  3826. a2 = [[riemann_cyclic_replace(tx) for tx in y] for y in a1]
  3827. a3 = [tensor_mul(*v) for v in a2]
  3828. t3 = TensAdd(*a3).doit()
  3829. if not t3:
  3830. return t3
  3831. else:
  3832. return canon_bp(t3)
  3833. def get_lines(ex, index_type):
  3834. """
  3835. Returns ``(lines, traces, rest)`` for an index type,
  3836. where ``lines`` is the list of list of positions of a matrix line,
  3837. ``traces`` is the list of list of traced matrix lines,
  3838. ``rest`` is the rest of the elements of the tensor.
  3839. """
  3840. def _join_lines(a):
  3841. i = 0
  3842. while i < len(a):
  3843. x = a[i]
  3844. xend = x[-1]
  3845. xstart = x[0]
  3846. hit = True
  3847. while hit:
  3848. hit = False
  3849. for j in range(i + 1, len(a)):
  3850. if j >= len(a):
  3851. break
  3852. if a[j][0] == xend:
  3853. hit = True
  3854. x.extend(a[j][1:])
  3855. xend = x[-1]
  3856. a.pop(j)
  3857. continue
  3858. if a[j][0] == xstart:
  3859. hit = True
  3860. a[i] = reversed(a[j][1:]) + x
  3861. x = a[i]
  3862. xstart = a[i][0]
  3863. a.pop(j)
  3864. continue
  3865. if a[j][-1] == xend:
  3866. hit = True
  3867. x.extend(reversed(a[j][:-1]))
  3868. xend = x[-1]
  3869. a.pop(j)
  3870. continue
  3871. if a[j][-1] == xstart:
  3872. hit = True
  3873. a[i] = a[j][:-1] + x
  3874. x = a[i]
  3875. xstart = x[0]
  3876. a.pop(j)
  3877. continue
  3878. i += 1
  3879. return a
  3880. arguments = ex.args
  3881. dt = {}
  3882. for c in ex.args:
  3883. if not isinstance(c, TensExpr):
  3884. continue
  3885. if c in dt:
  3886. continue
  3887. index_types = c.index_types
  3888. a = []
  3889. for i in range(len(index_types)):
  3890. if index_types[i] is index_type:
  3891. a.append(i)
  3892. if len(a) > 2:
  3893. raise ValueError('at most two indices of type %s allowed' % index_type)
  3894. if len(a) == 2:
  3895. dt[c] = a
  3896. #dum = ex.dum
  3897. lines = []
  3898. traces = []
  3899. traces1 = []
  3900. #indices_to_args_pos = ex._get_indices_to_args_pos()
  3901. # TODO: add a dum_to_components_map ?
  3902. for p0, p1, c0, c1 in ex.dum_in_args:
  3903. if arguments[c0] not in dt:
  3904. continue
  3905. if c0 == c1:
  3906. traces.append([c0])
  3907. continue
  3908. ta0 = dt[arguments[c0]]
  3909. ta1 = dt[arguments[c1]]
  3910. if p0 not in ta0:
  3911. continue
  3912. if ta0.index(p0) == ta1.index(p1):
  3913. # case gamma(i,s0,-s1) in c0, gamma(j,-s0,s2) in c1;
  3914. # to deal with this case one could add to the position
  3915. # a flag for transposition;
  3916. # one could write [(c0, False), (c1, True)]
  3917. raise NotImplementedError
  3918. # if p0 == ta0[1] then G in pos c0 is mult on the right by G in c1
  3919. # if p0 == ta0[0] then G in pos c1 is mult on the right by G in c0
  3920. ta0 = dt[arguments[c0]]
  3921. b0, b1 = (c0, c1) if p0 == ta0[1] else (c1, c0)
  3922. lines1 = lines[:]
  3923. for line in lines:
  3924. if line[-1] == b0:
  3925. if line[0] == b1:
  3926. n = line.index(min(line))
  3927. traces1.append(line)
  3928. traces.append(line[n:] + line[:n])
  3929. else:
  3930. line.append(b1)
  3931. break
  3932. elif line[0] == b1:
  3933. line.insert(0, b0)
  3934. break
  3935. else:
  3936. lines1.append([b0, b1])
  3937. lines = [x for x in lines1 if x not in traces1]
  3938. lines = _join_lines(lines)
  3939. rest = []
  3940. for line in lines:
  3941. for y in line:
  3942. rest.append(y)
  3943. for line in traces:
  3944. for y in line:
  3945. rest.append(y)
  3946. rest = [x for x in range(len(arguments)) if x not in rest]
  3947. return lines, traces, rest
  3948. def get_free_indices(t):
  3949. if not isinstance(t, TensExpr):
  3950. return ()
  3951. return t.get_free_indices()
  3952. def get_indices(t):
  3953. if not isinstance(t, TensExpr):
  3954. return ()
  3955. return t.get_indices()
  3956. def get_dummy_indices(t):
  3957. if not isinstance(t, TensExpr):
  3958. return ()
  3959. inds = t.get_indices()
  3960. free = t.get_free_indices()
  3961. return [i for i in inds if i not in free]
  3962. def get_index_structure(t):
  3963. if isinstance(t, TensExpr):
  3964. return t._index_structure
  3965. return _IndexStructure([], [], [], [])
  3966. def get_coeff(t):
  3967. if isinstance(t, Tensor):
  3968. return S.One
  3969. if isinstance(t, TensMul):
  3970. return t.coeff
  3971. if isinstance(t, TensExpr):
  3972. raise ValueError("no coefficient associated to this tensor expression")
  3973. return t
  3974. def contract_metric(t, g):
  3975. if isinstance(t, TensExpr):
  3976. return t.contract_metric(g)
  3977. return t
  3978. def perm2tensor(t, g, is_canon_bp=False):
  3979. """
  3980. Returns the tensor corresponding to the permutation ``g``
  3981. For further details, see the method in ``TIDS`` with the same name.
  3982. """
  3983. if not isinstance(t, TensExpr):
  3984. return t
  3985. elif isinstance(t, (Tensor, TensMul)):
  3986. nim = get_index_structure(t).perm2tensor(g, is_canon_bp=is_canon_bp)
  3987. res = t._set_new_index_structure(nim, is_canon_bp=is_canon_bp)
  3988. if g[-1] != len(g) - 1:
  3989. return -res
  3990. return res
  3991. raise NotImplementedError()
  3992. def substitute_indices(t, *index_tuples):
  3993. if not isinstance(t, TensExpr):
  3994. return t
  3995. return t.substitute_indices(*index_tuples)
  3996. def _expand(expr, **kwargs):
  3997. if isinstance(expr, TensExpr):
  3998. return expr._expand(**kwargs)
  3999. else:
  4000. return expr.expand(**kwargs)