test_signaltools.py 134 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733173417351736173717381739174017411742174317441745174617471748174917501751175217531754175517561757175817591760176117621763176417651766176717681769177017711772177317741775177617771778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209321032113212321332143215321632173218321932203221322232233224322532263227322832293230323132323233323432353236323732383239324032413242324332443245324632473248324932503251325232533254325532563257325832593260326132623263326432653266326732683269327032713272327332743275327632773278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575
  1. # -*- coding: utf-8 -*-
  2. import sys
  3. from concurrent.futures import ThreadPoolExecutor, as_completed
  4. from decimal import Decimal
  5. from itertools import product
  6. from math import gcd
  7. import pytest
  8. from pytest import raises as assert_raises
  9. from numpy.testing import (
  10. assert_equal,
  11. assert_almost_equal, assert_array_equal, assert_array_almost_equal,
  12. assert_allclose, assert_, assert_array_less,
  13. suppress_warnings)
  14. from numpy import array, arange
  15. import numpy as np
  16. from scipy.fft import fft
  17. from scipy.ndimage import correlate1d
  18. from scipy.optimize import fmin, linear_sum_assignment
  19. from scipy import signal
  20. from scipy.signal import (
  21. correlate, correlate2d, correlation_lags, convolve, convolve2d,
  22. fftconvolve, oaconvolve, choose_conv_method,
  23. hilbert, hilbert2, lfilter, lfilter_zi, filtfilt, butter, zpk2tf, zpk2sos,
  24. invres, invresz, vectorstrength, lfiltic, tf2sos, sosfilt, sosfiltfilt,
  25. sosfilt_zi, tf2zpk, BadCoefficients, detrend, unique_roots, residue,
  26. residuez)
  27. from scipy.signal.windows import hann
  28. from scipy.signal._signaltools import (_filtfilt_gust, _compute_factors,
  29. _group_poles)
  30. from scipy.signal._upfirdn import _upfirdn_modes
  31. from scipy._lib import _testutils
  32. class _TestConvolve:
  33. def test_basic(self):
  34. a = [3, 4, 5, 6, 5, 4]
  35. b = [1, 2, 3]
  36. c = convolve(a, b)
  37. assert_array_equal(c, array([3, 10, 22, 28, 32, 32, 23, 12]))
  38. def test_same(self):
  39. a = [3, 4, 5]
  40. b = [1, 2, 3, 4]
  41. c = convolve(a, b, mode="same")
  42. assert_array_equal(c, array([10, 22, 34]))
  43. def test_same_eq(self):
  44. a = [3, 4, 5]
  45. b = [1, 2, 3]
  46. c = convolve(a, b, mode="same")
  47. assert_array_equal(c, array([10, 22, 22]))
  48. def test_complex(self):
  49. x = array([1 + 1j, 2 + 1j, 3 + 1j])
  50. y = array([1 + 1j, 2 + 1j])
  51. z = convolve(x, y)
  52. assert_array_equal(z, array([2j, 2 + 6j, 5 + 8j, 5 + 5j]))
  53. def test_zero_rank(self):
  54. a = 1289
  55. b = 4567
  56. c = convolve(a, b)
  57. assert_equal(c, a * b)
  58. def test_broadcastable(self):
  59. a = np.arange(27).reshape(3, 3, 3)
  60. b = np.arange(3)
  61. for i in range(3):
  62. b_shape = [1]*3
  63. b_shape[i] = 3
  64. x = convolve(a, b.reshape(b_shape), method='direct')
  65. y = convolve(a, b.reshape(b_shape), method='fft')
  66. assert_allclose(x, y)
  67. def test_single_element(self):
  68. a = array([4967])
  69. b = array([3920])
  70. c = convolve(a, b)
  71. assert_equal(c, a * b)
  72. def test_2d_arrays(self):
  73. a = [[1, 2, 3], [3, 4, 5]]
  74. b = [[2, 3, 4], [4, 5, 6]]
  75. c = convolve(a, b)
  76. d = array([[2, 7, 16, 17, 12],
  77. [10, 30, 62, 58, 38],
  78. [12, 31, 58, 49, 30]])
  79. assert_array_equal(c, d)
  80. def test_input_swapping(self):
  81. small = arange(8).reshape(2, 2, 2)
  82. big = 1j * arange(27).reshape(3, 3, 3)
  83. big += arange(27)[::-1].reshape(3, 3, 3)
  84. out_array = array(
  85. [[[0 + 0j, 26 + 0j, 25 + 1j, 24 + 2j],
  86. [52 + 0j, 151 + 5j, 145 + 11j, 93 + 11j],
  87. [46 + 6j, 133 + 23j, 127 + 29j, 81 + 23j],
  88. [40 + 12j, 98 + 32j, 93 + 37j, 54 + 24j]],
  89. [[104 + 0j, 247 + 13j, 237 + 23j, 135 + 21j],
  90. [282 + 30j, 632 + 96j, 604 + 124j, 330 + 86j],
  91. [246 + 66j, 548 + 180j, 520 + 208j, 282 + 134j],
  92. [142 + 66j, 307 + 161j, 289 + 179j, 153 + 107j]],
  93. [[68 + 36j, 157 + 103j, 147 + 113j, 81 + 75j],
  94. [174 + 138j, 380 + 348j, 352 + 376j, 186 + 230j],
  95. [138 + 174j, 296 + 432j, 268 + 460j, 138 + 278j],
  96. [70 + 138j, 145 + 323j, 127 + 341j, 63 + 197j]],
  97. [[32 + 72j, 68 + 166j, 59 + 175j, 30 + 100j],
  98. [68 + 192j, 139 + 433j, 117 + 455j, 57 + 255j],
  99. [38 + 222j, 73 + 499j, 51 + 521j, 21 + 291j],
  100. [12 + 144j, 20 + 318j, 7 + 331j, 0 + 182j]]])
  101. assert_array_equal(convolve(small, big, 'full'), out_array)
  102. assert_array_equal(convolve(big, small, 'full'), out_array)
  103. assert_array_equal(convolve(small, big, 'same'),
  104. out_array[1:3, 1:3, 1:3])
  105. assert_array_equal(convolve(big, small, 'same'),
  106. out_array[0:3, 0:3, 0:3])
  107. assert_array_equal(convolve(small, big, 'valid'),
  108. out_array[1:3, 1:3, 1:3])
  109. assert_array_equal(convolve(big, small, 'valid'),
  110. out_array[1:3, 1:3, 1:3])
  111. def test_invalid_params(self):
  112. a = [3, 4, 5]
  113. b = [1, 2, 3]
  114. assert_raises(ValueError, convolve, a, b, mode='spam')
  115. assert_raises(ValueError, convolve, a, b, mode='eggs', method='fft')
  116. assert_raises(ValueError, convolve, a, b, mode='ham', method='direct')
  117. assert_raises(ValueError, convolve, a, b, mode='full', method='bacon')
  118. assert_raises(ValueError, convolve, a, b, mode='same', method='bacon')
  119. class TestConvolve(_TestConvolve):
  120. def test_valid_mode2(self):
  121. # See gh-5897
  122. a = [1, 2, 3, 6, 5, 3]
  123. b = [2, 3, 4, 5, 3, 4, 2, 2, 1]
  124. expected = [70, 78, 73, 65]
  125. out = convolve(a, b, 'valid')
  126. assert_array_equal(out, expected)
  127. out = convolve(b, a, 'valid')
  128. assert_array_equal(out, expected)
  129. a = [1 + 5j, 2 - 1j, 3 + 0j]
  130. b = [2 - 3j, 1 + 0j]
  131. expected = [2 - 3j, 8 - 10j]
  132. out = convolve(a, b, 'valid')
  133. assert_array_equal(out, expected)
  134. out = convolve(b, a, 'valid')
  135. assert_array_equal(out, expected)
  136. def test_same_mode(self):
  137. a = [1, 2, 3, 3, 1, 2]
  138. b = [1, 4, 3, 4, 5, 6, 7, 4, 3, 2, 1, 1, 3]
  139. c = convolve(a, b, 'same')
  140. d = array([57, 61, 63, 57, 45, 36])
  141. assert_array_equal(c, d)
  142. def test_invalid_shapes(self):
  143. # By "invalid," we mean that no one
  144. # array has dimensions that are all at
  145. # least as large as the corresponding
  146. # dimensions of the other array. This
  147. # setup should throw a ValueError.
  148. a = np.arange(1, 7).reshape((2, 3))
  149. b = np.arange(-6, 0).reshape((3, 2))
  150. assert_raises(ValueError, convolve, *(a, b), **{'mode': 'valid'})
  151. assert_raises(ValueError, convolve, *(b, a), **{'mode': 'valid'})
  152. def test_convolve_method(self, n=100):
  153. types = sum([t for _, t in np.sctypes.items()], [])
  154. types = {np.dtype(t).name for t in types}
  155. # These types include 'bool' and all precisions (int8, float32, etc)
  156. # The removed types throw errors in correlate or fftconvolve
  157. for dtype in ['complex256', 'complex192', 'float128', 'float96',
  158. 'str', 'void', 'bytes', 'object', 'unicode', 'string']:
  159. if dtype in types:
  160. types.remove(dtype)
  161. args = [(t1, t2, mode) for t1 in types for t2 in types
  162. for mode in ['valid', 'full', 'same']]
  163. # These are random arrays, which means test is much stronger than
  164. # convolving testing by convolving two np.ones arrays
  165. np.random.seed(42)
  166. array_types = {'i': np.random.choice([0, 1], size=n),
  167. 'f': np.random.randn(n)}
  168. array_types['b'] = array_types['u'] = array_types['i']
  169. array_types['c'] = array_types['f'] + 0.5j*array_types['f']
  170. for t1, t2, mode in args:
  171. x1 = array_types[np.dtype(t1).kind].astype(t1)
  172. x2 = array_types[np.dtype(t2).kind].astype(t2)
  173. results = {key: convolve(x1, x2, method=key, mode=mode)
  174. for key in ['fft', 'direct']}
  175. assert_equal(results['fft'].dtype, results['direct'].dtype)
  176. if 'bool' in t1 and 'bool' in t2:
  177. assert_equal(choose_conv_method(x1, x2), 'direct')
  178. continue
  179. # Found by experiment. Found approx smallest value for (rtol, atol)
  180. # threshold to have tests pass.
  181. if any([t in {'complex64', 'float32'} for t in [t1, t2]]):
  182. kwargs = {'rtol': 1.0e-4, 'atol': 1e-6}
  183. elif 'float16' in [t1, t2]:
  184. # atol is default for np.allclose
  185. kwargs = {'rtol': 1e-3, 'atol': 1e-3}
  186. else:
  187. # defaults for np.allclose (different from assert_allclose)
  188. kwargs = {'rtol': 1e-5, 'atol': 1e-8}
  189. assert_allclose(results['fft'], results['direct'], **kwargs)
  190. def test_convolve_method_large_input(self):
  191. # This is really a test that convolving two large integers goes to the
  192. # direct method even if they're in the fft method.
  193. for n in [10, 20, 50, 51, 52, 53, 54, 60, 62]:
  194. z = np.array([2**n], dtype=np.int64)
  195. fft = convolve(z, z, method='fft')
  196. direct = convolve(z, z, method='direct')
  197. # this is the case when integer precision gets to us
  198. # issue #6076 has more detail, hopefully more tests after resolved
  199. if n < 50:
  200. assert_equal(fft, direct)
  201. assert_equal(fft, 2**(2*n))
  202. assert_equal(direct, 2**(2*n))
  203. def test_mismatched_dims(self):
  204. # Input arrays should have the same number of dimensions
  205. assert_raises(ValueError, convolve, [1], 2, method='direct')
  206. assert_raises(ValueError, convolve, 1, [2], method='direct')
  207. assert_raises(ValueError, convolve, [1], 2, method='fft')
  208. assert_raises(ValueError, convolve, 1, [2], method='fft')
  209. assert_raises(ValueError, convolve, [1], [[2]])
  210. assert_raises(ValueError, convolve, [3], 2)
  211. class _TestConvolve2d:
  212. def test_2d_arrays(self):
  213. a = [[1, 2, 3], [3, 4, 5]]
  214. b = [[2, 3, 4], [4, 5, 6]]
  215. d = array([[2, 7, 16, 17, 12],
  216. [10, 30, 62, 58, 38],
  217. [12, 31, 58, 49, 30]])
  218. e = convolve2d(a, b)
  219. assert_array_equal(e, d)
  220. def test_valid_mode(self):
  221. e = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  222. f = [[1, 2, 3], [3, 4, 5]]
  223. h = array([[62, 80, 98, 116, 134]])
  224. g = convolve2d(e, f, 'valid')
  225. assert_array_equal(g, h)
  226. # See gh-5897
  227. g = convolve2d(f, e, 'valid')
  228. assert_array_equal(g, h)
  229. def test_valid_mode_complx(self):
  230. e = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  231. f = np.array([[1, 2, 3], [3, 4, 5]], dtype=complex) + 1j
  232. h = array([[62.+24.j, 80.+30.j, 98.+36.j, 116.+42.j, 134.+48.j]])
  233. g = convolve2d(e, f, 'valid')
  234. assert_array_almost_equal(g, h)
  235. # See gh-5897
  236. g = convolve2d(f, e, 'valid')
  237. assert_array_equal(g, h)
  238. def test_fillvalue(self):
  239. a = [[1, 2, 3], [3, 4, 5]]
  240. b = [[2, 3, 4], [4, 5, 6]]
  241. fillval = 1
  242. c = convolve2d(a, b, 'full', 'fill', fillval)
  243. d = array([[24, 26, 31, 34, 32],
  244. [28, 40, 62, 64, 52],
  245. [32, 46, 67, 62, 48]])
  246. assert_array_equal(c, d)
  247. def test_fillvalue_errors(self):
  248. msg = "could not cast `fillvalue` directly to the output "
  249. with np.testing.suppress_warnings() as sup:
  250. sup.filter(np.ComplexWarning, "Casting complex values")
  251. with assert_raises(ValueError, match=msg):
  252. convolve2d([[1]], [[1, 2]], fillvalue=1j)
  253. msg = "`fillvalue` must be scalar or an array with "
  254. with assert_raises(ValueError, match=msg):
  255. convolve2d([[1]], [[1, 2]], fillvalue=[1, 2])
  256. def test_fillvalue_empty(self):
  257. # Check that fillvalue being empty raises an error:
  258. assert_raises(ValueError, convolve2d, [[1]], [[1, 2]],
  259. fillvalue=[])
  260. def test_wrap_boundary(self):
  261. a = [[1, 2, 3], [3, 4, 5]]
  262. b = [[2, 3, 4], [4, 5, 6]]
  263. c = convolve2d(a, b, 'full', 'wrap')
  264. d = array([[80, 80, 74, 80, 80],
  265. [68, 68, 62, 68, 68],
  266. [80, 80, 74, 80, 80]])
  267. assert_array_equal(c, d)
  268. def test_sym_boundary(self):
  269. a = [[1, 2, 3], [3, 4, 5]]
  270. b = [[2, 3, 4], [4, 5, 6]]
  271. c = convolve2d(a, b, 'full', 'symm')
  272. d = array([[34, 30, 44, 62, 66],
  273. [52, 48, 62, 80, 84],
  274. [82, 78, 92, 110, 114]])
  275. assert_array_equal(c, d)
  276. @pytest.mark.parametrize('func', [convolve2d, correlate2d])
  277. @pytest.mark.parametrize('boundary, expected',
  278. [('symm', [[37.0, 42.0, 44.0, 45.0]]),
  279. ('wrap', [[43.0, 44.0, 42.0, 39.0]])])
  280. def test_same_with_boundary(self, func, boundary, expected):
  281. # Test boundary='symm' and boundary='wrap' with a "long" kernel.
  282. # The size of the kernel requires that the values in the "image"
  283. # be extended more than once to handle the requested boundary method.
  284. # This is a regression test for gh-8684 and gh-8814.
  285. image = np.array([[2.0, -1.0, 3.0, 4.0]])
  286. kernel = np.ones((1, 21))
  287. result = func(image, kernel, mode='same', boundary=boundary)
  288. # The expected results were calculated "by hand". Because the
  289. # kernel is all ones, the same result is expected for convolve2d
  290. # and correlate2d.
  291. assert_array_equal(result, expected)
  292. def test_boundary_extension_same(self):
  293. # Regression test for gh-12686.
  294. # Use ndimage.convolve with appropriate arguments to create the
  295. # expected result.
  296. import scipy.ndimage as ndi
  297. a = np.arange(1, 10*3+1, dtype=float).reshape(10, 3)
  298. b = np.arange(1, 10*10+1, dtype=float).reshape(10, 10)
  299. c = convolve2d(a, b, mode='same', boundary='wrap')
  300. assert_array_equal(c, ndi.convolve(a, b, mode='wrap', origin=(-1, -1)))
  301. def test_boundary_extension_full(self):
  302. # Regression test for gh-12686.
  303. # Use ndimage.convolve with appropriate arguments to create the
  304. # expected result.
  305. import scipy.ndimage as ndi
  306. a = np.arange(1, 3*3+1, dtype=float).reshape(3, 3)
  307. b = np.arange(1, 6*6+1, dtype=float).reshape(6, 6)
  308. c = convolve2d(a, b, mode='full', boundary='wrap')
  309. apad = np.pad(a, ((3, 3), (3, 3)), 'wrap')
  310. assert_array_equal(c, ndi.convolve(apad, b, mode='wrap')[:-1, :-1])
  311. def test_invalid_shapes(self):
  312. # By "invalid," we mean that no one
  313. # array has dimensions that are all at
  314. # least as large as the corresponding
  315. # dimensions of the other array. This
  316. # setup should throw a ValueError.
  317. a = np.arange(1, 7).reshape((2, 3))
  318. b = np.arange(-6, 0).reshape((3, 2))
  319. assert_raises(ValueError, convolve2d, *(a, b), **{'mode': 'valid'})
  320. assert_raises(ValueError, convolve2d, *(b, a), **{'mode': 'valid'})
  321. class TestConvolve2d(_TestConvolve2d):
  322. def test_same_mode(self):
  323. e = [[1, 2, 3], [3, 4, 5]]
  324. f = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  325. g = convolve2d(e, f, 'same')
  326. h = array([[22, 28, 34],
  327. [80, 98, 116]])
  328. assert_array_equal(g, h)
  329. def test_valid_mode2(self):
  330. # See gh-5897
  331. e = [[1, 2, 3], [3, 4, 5]]
  332. f = [[2, 3, 4, 5, 6, 7, 8], [4, 5, 6, 7, 8, 9, 10]]
  333. expected = [[62, 80, 98, 116, 134]]
  334. out = convolve2d(e, f, 'valid')
  335. assert_array_equal(out, expected)
  336. out = convolve2d(f, e, 'valid')
  337. assert_array_equal(out, expected)
  338. e = [[1 + 1j, 2 - 3j], [3 + 1j, 4 + 0j]]
  339. f = [[2 - 1j, 3 + 2j, 4 + 0j], [4 - 0j, 5 + 1j, 6 - 3j]]
  340. expected = [[27 - 1j, 46. + 2j]]
  341. out = convolve2d(e, f, 'valid')
  342. assert_array_equal(out, expected)
  343. # See gh-5897
  344. out = convolve2d(f, e, 'valid')
  345. assert_array_equal(out, expected)
  346. def test_consistency_convolve_funcs(self):
  347. # Compare np.convolve, signal.convolve, signal.convolve2d
  348. a = np.arange(5)
  349. b = np.array([3.2, 1.4, 3])
  350. for mode in ['full', 'valid', 'same']:
  351. assert_almost_equal(np.convolve(a, b, mode=mode),
  352. signal.convolve(a, b, mode=mode))
  353. assert_almost_equal(np.squeeze(
  354. signal.convolve2d([a], [b], mode=mode)),
  355. signal.convolve(a, b, mode=mode))
  356. def test_invalid_dims(self):
  357. assert_raises(ValueError, convolve2d, 3, 4)
  358. assert_raises(ValueError, convolve2d, [3], [4])
  359. assert_raises(ValueError, convolve2d, [[[3]]], [[[4]]])
  360. @pytest.mark.slow
  361. @pytest.mark.xfail_on_32bit("Can't create large array for test")
  362. def test_large_array(self):
  363. # Test indexing doesn't overflow an int (gh-10761)
  364. n = 2**31 // (1000 * np.int64().itemsize)
  365. _testutils.check_free_memory(2 * n * 1001 * np.int64().itemsize / 1e6)
  366. # Create a chequered pattern of 1s and 0s
  367. a = np.zeros(1001 * n, dtype=np.int64)
  368. a[::2] = 1
  369. a = np.lib.stride_tricks.as_strided(a, shape=(n, 1000), strides=(8008, 8))
  370. count = signal.convolve2d(a, [[1, 1]])
  371. fails = np.where(count > 1)
  372. assert fails[0].size == 0
  373. class TestFFTConvolve:
  374. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  375. def test_real(self, axes):
  376. a = array([1, 2, 3])
  377. expected = array([1, 4, 10, 12, 9.])
  378. if axes == '':
  379. out = fftconvolve(a, a)
  380. else:
  381. out = fftconvolve(a, a, axes=axes)
  382. assert_array_almost_equal(out, expected)
  383. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  384. def test_real_axes(self, axes):
  385. a = array([1, 2, 3])
  386. expected = array([1, 4, 10, 12, 9.])
  387. a = np.tile(a, [2, 1])
  388. expected = np.tile(expected, [2, 1])
  389. out = fftconvolve(a, a, axes=axes)
  390. assert_array_almost_equal(out, expected)
  391. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  392. def test_complex(self, axes):
  393. a = array([1 + 1j, 2 + 2j, 3 + 3j])
  394. expected = array([0 + 2j, 0 + 8j, 0 + 20j, 0 + 24j, 0 + 18j])
  395. if axes == '':
  396. out = fftconvolve(a, a)
  397. else:
  398. out = fftconvolve(a, a, axes=axes)
  399. assert_array_almost_equal(out, expected)
  400. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  401. def test_complex_axes(self, axes):
  402. a = array([1 + 1j, 2 + 2j, 3 + 3j])
  403. expected = array([0 + 2j, 0 + 8j, 0 + 20j, 0 + 24j, 0 + 18j])
  404. a = np.tile(a, [2, 1])
  405. expected = np.tile(expected, [2, 1])
  406. out = fftconvolve(a, a, axes=axes)
  407. assert_array_almost_equal(out, expected)
  408. @pytest.mark.parametrize('axes', ['',
  409. None,
  410. [0, 1],
  411. [1, 0],
  412. [0, -1],
  413. [-1, 0],
  414. [-2, 1],
  415. [1, -2],
  416. [-2, -1],
  417. [-1, -2]])
  418. def test_2d_real_same(self, axes):
  419. a = array([[1, 2, 3],
  420. [4, 5, 6]])
  421. expected = array([[1, 4, 10, 12, 9],
  422. [8, 26, 56, 54, 36],
  423. [16, 40, 73, 60, 36]])
  424. if axes == '':
  425. out = fftconvolve(a, a)
  426. else:
  427. out = fftconvolve(a, a, axes=axes)
  428. assert_array_almost_equal(out, expected)
  429. @pytest.mark.parametrize('axes', [[1, 2],
  430. [2, 1],
  431. [1, -1],
  432. [-1, 1],
  433. [-2, 2],
  434. [2, -2],
  435. [-2, -1],
  436. [-1, -2]])
  437. def test_2d_real_same_axes(self, axes):
  438. a = array([[1, 2, 3],
  439. [4, 5, 6]])
  440. expected = array([[1, 4, 10, 12, 9],
  441. [8, 26, 56, 54, 36],
  442. [16, 40, 73, 60, 36]])
  443. a = np.tile(a, [2, 1, 1])
  444. expected = np.tile(expected, [2, 1, 1])
  445. out = fftconvolve(a, a, axes=axes)
  446. assert_array_almost_equal(out, expected)
  447. @pytest.mark.parametrize('axes', ['',
  448. None,
  449. [0, 1],
  450. [1, 0],
  451. [0, -1],
  452. [-1, 0],
  453. [-2, 1],
  454. [1, -2],
  455. [-2, -1],
  456. [-1, -2]])
  457. def test_2d_complex_same(self, axes):
  458. a = array([[1 + 2j, 3 + 4j, 5 + 6j],
  459. [2 + 1j, 4 + 3j, 6 + 5j]])
  460. expected = array([
  461. [-3 + 4j, -10 + 20j, -21 + 56j, -18 + 76j, -11 + 60j],
  462. [10j, 44j, 118j, 156j, 122j],
  463. [3 + 4j, 10 + 20j, 21 + 56j, 18 + 76j, 11 + 60j]
  464. ])
  465. if axes == '':
  466. out = fftconvolve(a, a)
  467. else:
  468. out = fftconvolve(a, a, axes=axes)
  469. assert_array_almost_equal(out, expected)
  470. @pytest.mark.parametrize('axes', [[1, 2],
  471. [2, 1],
  472. [1, -1],
  473. [-1, 1],
  474. [-2, 2],
  475. [2, -2],
  476. [-2, -1],
  477. [-1, -2]])
  478. def test_2d_complex_same_axes(self, axes):
  479. a = array([[1 + 2j, 3 + 4j, 5 + 6j],
  480. [2 + 1j, 4 + 3j, 6 + 5j]])
  481. expected = array([
  482. [-3 + 4j, -10 + 20j, -21 + 56j, -18 + 76j, -11 + 60j],
  483. [10j, 44j, 118j, 156j, 122j],
  484. [3 + 4j, 10 + 20j, 21 + 56j, 18 + 76j, 11 + 60j]
  485. ])
  486. a = np.tile(a, [2, 1, 1])
  487. expected = np.tile(expected, [2, 1, 1])
  488. out = fftconvolve(a, a, axes=axes)
  489. assert_array_almost_equal(out, expected)
  490. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  491. def test_real_same_mode(self, axes):
  492. a = array([1, 2, 3])
  493. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  494. expected_1 = array([35., 41., 47.])
  495. expected_2 = array([9., 20., 25., 35., 41., 47., 39., 28., 2.])
  496. if axes == '':
  497. out = fftconvolve(a, b, 'same')
  498. else:
  499. out = fftconvolve(a, b, 'same', axes=axes)
  500. assert_array_almost_equal(out, expected_1)
  501. if axes == '':
  502. out = fftconvolve(b, a, 'same')
  503. else:
  504. out = fftconvolve(b, a, 'same', axes=axes)
  505. assert_array_almost_equal(out, expected_2)
  506. @pytest.mark.parametrize('axes', [1, -1, [1], [-1]])
  507. def test_real_same_mode_axes(self, axes):
  508. a = array([1, 2, 3])
  509. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  510. expected_1 = array([35., 41., 47.])
  511. expected_2 = array([9., 20., 25., 35., 41., 47., 39., 28., 2.])
  512. a = np.tile(a, [2, 1])
  513. b = np.tile(b, [2, 1])
  514. expected_1 = np.tile(expected_1, [2, 1])
  515. expected_2 = np.tile(expected_2, [2, 1])
  516. out = fftconvolve(a, b, 'same', axes=axes)
  517. assert_array_almost_equal(out, expected_1)
  518. out = fftconvolve(b, a, 'same', axes=axes)
  519. assert_array_almost_equal(out, expected_2)
  520. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  521. def test_valid_mode_real(self, axes):
  522. # See gh-5897
  523. a = array([3, 2, 1])
  524. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  525. expected = array([24., 31., 41., 43., 49., 25., 12.])
  526. if axes == '':
  527. out = fftconvolve(a, b, 'valid')
  528. else:
  529. out = fftconvolve(a, b, 'valid', axes=axes)
  530. assert_array_almost_equal(out, expected)
  531. if axes == '':
  532. out = fftconvolve(b, a, 'valid')
  533. else:
  534. out = fftconvolve(b, a, 'valid', axes=axes)
  535. assert_array_almost_equal(out, expected)
  536. @pytest.mark.parametrize('axes', [1, [1]])
  537. def test_valid_mode_real_axes(self, axes):
  538. # See gh-5897
  539. a = array([3, 2, 1])
  540. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  541. expected = array([24., 31., 41., 43., 49., 25., 12.])
  542. a = np.tile(a, [2, 1])
  543. b = np.tile(b, [2, 1])
  544. expected = np.tile(expected, [2, 1])
  545. out = fftconvolve(a, b, 'valid', axes=axes)
  546. assert_array_almost_equal(out, expected)
  547. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  548. def test_valid_mode_complex(self, axes):
  549. a = array([3 - 1j, 2 + 7j, 1 + 0j])
  550. b = array([3 + 2j, 3 - 3j, 5 + 0j, 6 - 1j, 8 + 0j])
  551. expected = array([45. + 12.j, 30. + 23.j, 48 + 32.j])
  552. if axes == '':
  553. out = fftconvolve(a, b, 'valid')
  554. else:
  555. out = fftconvolve(a, b, 'valid', axes=axes)
  556. assert_array_almost_equal(out, expected)
  557. if axes == '':
  558. out = fftconvolve(b, a, 'valid')
  559. else:
  560. out = fftconvolve(b, a, 'valid', axes=axes)
  561. assert_array_almost_equal(out, expected)
  562. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  563. def test_valid_mode_complex_axes(self, axes):
  564. a = array([3 - 1j, 2 + 7j, 1 + 0j])
  565. b = array([3 + 2j, 3 - 3j, 5 + 0j, 6 - 1j, 8 + 0j])
  566. expected = array([45. + 12.j, 30. + 23.j, 48 + 32.j])
  567. a = np.tile(a, [2, 1])
  568. b = np.tile(b, [2, 1])
  569. expected = np.tile(expected, [2, 1])
  570. out = fftconvolve(a, b, 'valid', axes=axes)
  571. assert_array_almost_equal(out, expected)
  572. out = fftconvolve(b, a, 'valid', axes=axes)
  573. assert_array_almost_equal(out, expected)
  574. def test_valid_mode_ignore_nonaxes(self):
  575. # See gh-5897
  576. a = array([3, 2, 1])
  577. b = array([3, 3, 5, 6, 8, 7, 9, 0, 1])
  578. expected = array([24., 31., 41., 43., 49., 25., 12.])
  579. a = np.tile(a, [2, 1])
  580. b = np.tile(b, [1, 1])
  581. expected = np.tile(expected, [2, 1])
  582. out = fftconvolve(a, b, 'valid', axes=1)
  583. assert_array_almost_equal(out, expected)
  584. def test_empty(self):
  585. # Regression test for #1745: crashes with 0-length input.
  586. assert_(fftconvolve([], []).size == 0)
  587. assert_(fftconvolve([5, 6], []).size == 0)
  588. assert_(fftconvolve([], [7]).size == 0)
  589. def test_zero_rank(self):
  590. a = array(4967)
  591. b = array(3920)
  592. out = fftconvolve(a, b)
  593. assert_equal(out, a * b)
  594. def test_single_element(self):
  595. a = array([4967])
  596. b = array([3920])
  597. out = fftconvolve(a, b)
  598. assert_equal(out, a * b)
  599. @pytest.mark.parametrize('axes', ['', None, 0, [0], -1, [-1]])
  600. def test_random_data(self, axes):
  601. np.random.seed(1234)
  602. a = np.random.rand(1233) + 1j * np.random.rand(1233)
  603. b = np.random.rand(1321) + 1j * np.random.rand(1321)
  604. expected = np.convolve(a, b, 'full')
  605. if axes == '':
  606. out = fftconvolve(a, b, 'full')
  607. else:
  608. out = fftconvolve(a, b, 'full', axes=axes)
  609. assert_(np.allclose(out, expected, rtol=1e-10))
  610. @pytest.mark.parametrize('axes', [1, [1], -1, [-1]])
  611. def test_random_data_axes(self, axes):
  612. np.random.seed(1234)
  613. a = np.random.rand(1233) + 1j * np.random.rand(1233)
  614. b = np.random.rand(1321) + 1j * np.random.rand(1321)
  615. expected = np.convolve(a, b, 'full')
  616. a = np.tile(a, [2, 1])
  617. b = np.tile(b, [2, 1])
  618. expected = np.tile(expected, [2, 1])
  619. out = fftconvolve(a, b, 'full', axes=axes)
  620. assert_(np.allclose(out, expected, rtol=1e-10))
  621. @pytest.mark.parametrize('axes', [[1, 4],
  622. [4, 1],
  623. [1, -1],
  624. [-1, 1],
  625. [-4, 4],
  626. [4, -4],
  627. [-4, -1],
  628. [-1, -4]])
  629. def test_random_data_multidim_axes(self, axes):
  630. a_shape, b_shape = (123, 22), (132, 11)
  631. np.random.seed(1234)
  632. a = np.random.rand(*a_shape) + 1j * np.random.rand(*a_shape)
  633. b = np.random.rand(*b_shape) + 1j * np.random.rand(*b_shape)
  634. expected = convolve2d(a, b, 'full')
  635. a = a[:, :, None, None, None]
  636. b = b[:, :, None, None, None]
  637. expected = expected[:, :, None, None, None]
  638. a = np.moveaxis(a.swapaxes(0, 2), 1, 4)
  639. b = np.moveaxis(b.swapaxes(0, 2), 1, 4)
  640. expected = np.moveaxis(expected.swapaxes(0, 2), 1, 4)
  641. # use 1 for dimension 2 in a and 3 in b to test broadcasting
  642. a = np.tile(a, [2, 1, 3, 1, 1])
  643. b = np.tile(b, [2, 1, 1, 4, 1])
  644. expected = np.tile(expected, [2, 1, 3, 4, 1])
  645. out = fftconvolve(a, b, 'full', axes=axes)
  646. assert_allclose(out, expected, rtol=1e-10, atol=1e-10)
  647. @pytest.mark.slow
  648. @pytest.mark.parametrize(
  649. 'n',
  650. list(range(1, 100)) +
  651. list(range(1000, 1500)) +
  652. np.random.RandomState(1234).randint(1001, 10000, 5).tolist())
  653. def test_many_sizes(self, n):
  654. a = np.random.rand(n) + 1j * np.random.rand(n)
  655. b = np.random.rand(n) + 1j * np.random.rand(n)
  656. expected = np.convolve(a, b, 'full')
  657. out = fftconvolve(a, b, 'full')
  658. assert_allclose(out, expected, atol=1e-10)
  659. out = fftconvolve(a, b, 'full', axes=[0])
  660. assert_allclose(out, expected, atol=1e-10)
  661. def test_fft_nan(self):
  662. n = 1000
  663. rng = np.random.default_rng(43876432987)
  664. sig_nan = rng.standard_normal(n)
  665. for val in [np.nan, np.inf]:
  666. sig_nan[100] = val
  667. coeffs = signal.firwin(200, 0.2)
  668. with pytest.warns(RuntimeWarning, match="Use of fft convolution"):
  669. signal.convolve(sig_nan, coeffs, mode='same', method='fft')
  670. def fftconvolve_err(*args, **kwargs):
  671. raise RuntimeError('Fell back to fftconvolve')
  672. def gen_oa_shapes(sizes):
  673. return [(a, b) for a, b in product(sizes, repeat=2)
  674. if abs(a - b) > 3]
  675. def gen_oa_shapes_2d(sizes):
  676. shapes0 = gen_oa_shapes(sizes)
  677. shapes1 = gen_oa_shapes(sizes)
  678. shapes = [ishapes0+ishapes1 for ishapes0, ishapes1 in
  679. zip(shapes0, shapes1)]
  680. modes = ['full', 'valid', 'same']
  681. return [ishapes+(imode,) for ishapes, imode in product(shapes, modes)
  682. if imode != 'valid' or
  683. (ishapes[0] > ishapes[1] and ishapes[2] > ishapes[3]) or
  684. (ishapes[0] < ishapes[1] and ishapes[2] < ishapes[3])]
  685. def gen_oa_shapes_eq(sizes):
  686. return [(a, b) for a, b in product(sizes, repeat=2)
  687. if a >= b]
  688. class TestOAConvolve:
  689. @pytest.mark.slow()
  690. @pytest.mark.parametrize('shape_a_0, shape_b_0',
  691. gen_oa_shapes_eq(list(range(100)) +
  692. list(range(100, 1000, 23)))
  693. )
  694. def test_real_manylens(self, shape_a_0, shape_b_0):
  695. a = np.random.rand(shape_a_0)
  696. b = np.random.rand(shape_b_0)
  697. expected = fftconvolve(a, b)
  698. out = oaconvolve(a, b)
  699. assert_array_almost_equal(out, expected)
  700. @pytest.mark.parametrize('shape_a_0, shape_b_0',
  701. gen_oa_shapes([50, 47, 6, 4, 1]))
  702. @pytest.mark.parametrize('is_complex', [True, False])
  703. @pytest.mark.parametrize('mode', ['full', 'valid', 'same'])
  704. def test_1d_noaxes(self, shape_a_0, shape_b_0,
  705. is_complex, mode, monkeypatch):
  706. a = np.random.rand(shape_a_0)
  707. b = np.random.rand(shape_b_0)
  708. if is_complex:
  709. a = a + 1j*np.random.rand(shape_a_0)
  710. b = b + 1j*np.random.rand(shape_b_0)
  711. expected = fftconvolve(a, b, mode=mode)
  712. monkeypatch.setattr(signal._signaltools, 'fftconvolve',
  713. fftconvolve_err)
  714. out = oaconvolve(a, b, mode=mode)
  715. assert_array_almost_equal(out, expected)
  716. @pytest.mark.parametrize('axes', [0, 1])
  717. @pytest.mark.parametrize('shape_a_0, shape_b_0',
  718. gen_oa_shapes([50, 47, 6, 4]))
  719. @pytest.mark.parametrize('shape_a_extra', [1, 3])
  720. @pytest.mark.parametrize('shape_b_extra', [1, 3])
  721. @pytest.mark.parametrize('is_complex', [True, False])
  722. @pytest.mark.parametrize('mode', ['full', 'valid', 'same'])
  723. def test_1d_axes(self, axes, shape_a_0, shape_b_0,
  724. shape_a_extra, shape_b_extra,
  725. is_complex, mode, monkeypatch):
  726. ax_a = [shape_a_extra]*2
  727. ax_b = [shape_b_extra]*2
  728. ax_a[axes] = shape_a_0
  729. ax_b[axes] = shape_b_0
  730. a = np.random.rand(*ax_a)
  731. b = np.random.rand(*ax_b)
  732. if is_complex:
  733. a = a + 1j*np.random.rand(*ax_a)
  734. b = b + 1j*np.random.rand(*ax_b)
  735. expected = fftconvolve(a, b, mode=mode, axes=axes)
  736. monkeypatch.setattr(signal._signaltools, 'fftconvolve',
  737. fftconvolve_err)
  738. out = oaconvolve(a, b, mode=mode, axes=axes)
  739. assert_array_almost_equal(out, expected)
  740. @pytest.mark.parametrize('shape_a_0, shape_b_0, '
  741. 'shape_a_1, shape_b_1, mode',
  742. gen_oa_shapes_2d([50, 47, 6, 4]))
  743. @pytest.mark.parametrize('is_complex', [True, False])
  744. def test_2d_noaxes(self, shape_a_0, shape_b_0,
  745. shape_a_1, shape_b_1, mode,
  746. is_complex, monkeypatch):
  747. a = np.random.rand(shape_a_0, shape_a_1)
  748. b = np.random.rand(shape_b_0, shape_b_1)
  749. if is_complex:
  750. a = a + 1j*np.random.rand(shape_a_0, shape_a_1)
  751. b = b + 1j*np.random.rand(shape_b_0, shape_b_1)
  752. expected = fftconvolve(a, b, mode=mode)
  753. monkeypatch.setattr(signal._signaltools, 'fftconvolve',
  754. fftconvolve_err)
  755. out = oaconvolve(a, b, mode=mode)
  756. assert_array_almost_equal(out, expected)
  757. @pytest.mark.parametrize('axes', [[0, 1], [0, 2], [1, 2]])
  758. @pytest.mark.parametrize('shape_a_0, shape_b_0, '
  759. 'shape_a_1, shape_b_1, mode',
  760. gen_oa_shapes_2d([50, 47, 6, 4]))
  761. @pytest.mark.parametrize('shape_a_extra', [1, 3])
  762. @pytest.mark.parametrize('shape_b_extra', [1, 3])
  763. @pytest.mark.parametrize('is_complex', [True, False])
  764. def test_2d_axes(self, axes, shape_a_0, shape_b_0,
  765. shape_a_1, shape_b_1, mode,
  766. shape_a_extra, shape_b_extra,
  767. is_complex, monkeypatch):
  768. ax_a = [shape_a_extra]*3
  769. ax_b = [shape_b_extra]*3
  770. ax_a[axes[0]] = shape_a_0
  771. ax_b[axes[0]] = shape_b_0
  772. ax_a[axes[1]] = shape_a_1
  773. ax_b[axes[1]] = shape_b_1
  774. a = np.random.rand(*ax_a)
  775. b = np.random.rand(*ax_b)
  776. if is_complex:
  777. a = a + 1j*np.random.rand(*ax_a)
  778. b = b + 1j*np.random.rand(*ax_b)
  779. expected = fftconvolve(a, b, mode=mode, axes=axes)
  780. monkeypatch.setattr(signal._signaltools, 'fftconvolve',
  781. fftconvolve_err)
  782. out = oaconvolve(a, b, mode=mode, axes=axes)
  783. assert_array_almost_equal(out, expected)
  784. def test_empty(self):
  785. # Regression test for #1745: crashes with 0-length input.
  786. assert_(oaconvolve([], []).size == 0)
  787. assert_(oaconvolve([5, 6], []).size == 0)
  788. assert_(oaconvolve([], [7]).size == 0)
  789. def test_zero_rank(self):
  790. a = array(4967)
  791. b = array(3920)
  792. out = oaconvolve(a, b)
  793. assert_equal(out, a * b)
  794. def test_single_element(self):
  795. a = array([4967])
  796. b = array([3920])
  797. out = oaconvolve(a, b)
  798. assert_equal(out, a * b)
  799. class TestAllFreqConvolves:
  800. @pytest.mark.parametrize('convapproach',
  801. [fftconvolve, oaconvolve])
  802. def test_invalid_shapes(self, convapproach):
  803. a = np.arange(1, 7).reshape((2, 3))
  804. b = np.arange(-6, 0).reshape((3, 2))
  805. with assert_raises(ValueError,
  806. match="For 'valid' mode, one must be at least "
  807. "as large as the other in every dimension"):
  808. convapproach(a, b, mode='valid')
  809. @pytest.mark.parametrize('convapproach',
  810. [fftconvolve, oaconvolve])
  811. def test_invalid_shapes_axes(self, convapproach):
  812. a = np.zeros([5, 6, 2, 1])
  813. b = np.zeros([5, 6, 3, 1])
  814. with assert_raises(ValueError,
  815. match=r"incompatible shapes for in1 and in2:"
  816. r" \(5L?, 6L?, 2L?, 1L?\) and"
  817. r" \(5L?, 6L?, 3L?, 1L?\)"):
  818. convapproach(a, b, axes=[0, 1])
  819. @pytest.mark.parametrize('a,b',
  820. [([1], 2),
  821. (1, [2]),
  822. ([3], [[2]])])
  823. @pytest.mark.parametrize('convapproach',
  824. [fftconvolve, oaconvolve])
  825. def test_mismatched_dims(self, a, b, convapproach):
  826. with assert_raises(ValueError,
  827. match="in1 and in2 should have the same"
  828. " dimensionality"):
  829. convapproach(a, b)
  830. @pytest.mark.parametrize('convapproach',
  831. [fftconvolve, oaconvolve])
  832. def test_invalid_flags(self, convapproach):
  833. with assert_raises(ValueError,
  834. match="acceptable mode flags are 'valid',"
  835. " 'same', or 'full'"):
  836. convapproach([1], [2], mode='chips')
  837. with assert_raises(ValueError,
  838. match="when provided, axes cannot be empty"):
  839. convapproach([1], [2], axes=[])
  840. with assert_raises(ValueError, match="axes must be a scalar or "
  841. "iterable of integers"):
  842. convapproach([1], [2], axes=[[1, 2], [3, 4]])
  843. with assert_raises(ValueError, match="axes must be a scalar or "
  844. "iterable of integers"):
  845. convapproach([1], [2], axes=[1., 2., 3., 4.])
  846. with assert_raises(ValueError,
  847. match="axes exceeds dimensionality of input"):
  848. convapproach([1], [2], axes=[1])
  849. with assert_raises(ValueError,
  850. match="axes exceeds dimensionality of input"):
  851. convapproach([1], [2], axes=[-2])
  852. with assert_raises(ValueError,
  853. match="all axes must be unique"):
  854. convapproach([1], [2], axes=[0, 0])
  855. @pytest.mark.parametrize('dtype', [np.longfloat, np.longcomplex])
  856. def test_longdtype_input(self, dtype):
  857. x = np.random.random((27, 27)).astype(dtype)
  858. y = np.random.random((4, 4)).astype(dtype)
  859. if np.iscomplexobj(dtype()):
  860. x += .1j
  861. y -= .1j
  862. res = fftconvolve(x, y)
  863. assert_allclose(res, convolve(x, y, method='direct'))
  864. assert res.dtype == dtype
  865. class TestMedFilt:
  866. IN = [[50, 50, 50, 50, 50, 92, 18, 27, 65, 46],
  867. [50, 50, 50, 50, 50, 0, 72, 77, 68, 66],
  868. [50, 50, 50, 50, 50, 46, 47, 19, 64, 77],
  869. [50, 50, 50, 50, 50, 42, 15, 29, 95, 35],
  870. [50, 50, 50, 50, 50, 46, 34, 9, 21, 66],
  871. [70, 97, 28, 68, 78, 77, 61, 58, 71, 42],
  872. [64, 53, 44, 29, 68, 32, 19, 68, 24, 84],
  873. [3, 33, 53, 67, 1, 78, 74, 55, 12, 83],
  874. [7, 11, 46, 70, 60, 47, 24, 43, 61, 26],
  875. [32, 61, 88, 7, 39, 4, 92, 64, 45, 61]]
  876. OUT = [[0, 50, 50, 50, 42, 15, 15, 18, 27, 0],
  877. [0, 50, 50, 50, 50, 42, 19, 21, 29, 0],
  878. [50, 50, 50, 50, 50, 47, 34, 34, 46, 35],
  879. [50, 50, 50, 50, 50, 50, 42, 47, 64, 42],
  880. [50, 50, 50, 50, 50, 50, 46, 55, 64, 35],
  881. [33, 50, 50, 50, 50, 47, 46, 43, 55, 26],
  882. [32, 50, 50, 50, 50, 47, 46, 45, 55, 26],
  883. [7, 46, 50, 50, 47, 46, 46, 43, 45, 21],
  884. [0, 32, 33, 39, 32, 32, 43, 43, 43, 0],
  885. [0, 7, 11, 7, 4, 4, 19, 19, 24, 0]]
  886. KERNEL_SIZE = [7,3]
  887. def test_basic(self):
  888. d = signal.medfilt(self.IN, self.KERNEL_SIZE)
  889. e = signal.medfilt2d(np.array(self.IN, float), self.KERNEL_SIZE)
  890. assert_array_equal(d, self.OUT)
  891. assert_array_equal(d, e)
  892. @pytest.mark.parametrize('dtype', [np.ubyte, np.byte, np.ushort, np.short,
  893. np.uint, int, np.ulonglong, np.ulonglong,
  894. np.float32, np.float64, np.longdouble])
  895. def test_types(self, dtype):
  896. # volume input and output types match
  897. in_typed = np.array(self.IN, dtype=dtype)
  898. assert_equal(signal.medfilt(in_typed).dtype, dtype)
  899. assert_equal(signal.medfilt2d(in_typed).dtype, dtype)
  900. @pytest.mark.parametrize('dtype', [np.bool_, np.cfloat, np.cdouble,
  901. np.clongdouble, np.float16,])
  902. def test_invalid_dtypes(self, dtype):
  903. in_typed = np.array(self.IN, dtype=dtype)
  904. with pytest.raises(ValueError, match="order_filterND"):
  905. signal.medfilt(in_typed)
  906. with pytest.raises(ValueError, match="order_filterND"):
  907. signal.medfilt2d(in_typed)
  908. def test_none(self):
  909. # gh-1651, trac #1124. Ensure this does not segfault.
  910. with pytest.warns(UserWarning):
  911. assert_raises(TypeError, signal.medfilt, None)
  912. # Expand on this test to avoid a regression with possible contiguous
  913. # numpy arrays that have odd strides. The stride value below gets
  914. # us into wrong memory if used (but it does not need to be used)
  915. dummy = np.arange(10, dtype=np.float64)
  916. a = dummy[5:6]
  917. a.strides = 16
  918. assert_(signal.medfilt(a, 1) == 5.)
  919. def test_refcounting(self):
  920. # Check a refcounting-related crash
  921. a = Decimal(123)
  922. x = np.array([a, a], dtype=object)
  923. if hasattr(sys, 'getrefcount'):
  924. n = 2 * sys.getrefcount(a)
  925. else:
  926. n = 10
  927. # Shouldn't segfault:
  928. with pytest.warns(UserWarning):
  929. for j in range(n):
  930. signal.medfilt(x)
  931. if hasattr(sys, 'getrefcount'):
  932. assert_(sys.getrefcount(a) < n)
  933. assert_equal(x, [a, a])
  934. def test_object(self,):
  935. in_object = np.array(self.IN, dtype=object)
  936. out_object = np.array(self.OUT, dtype=object)
  937. assert_array_equal(signal.medfilt(in_object, self.KERNEL_SIZE),
  938. out_object)
  939. @pytest.mark.parametrize("dtype", [np.ubyte, np.float32, np.float64])
  940. def test_medfilt2d_parallel(self, dtype):
  941. in_typed = np.array(self.IN, dtype=dtype)
  942. expected = np.array(self.OUT, dtype=dtype)
  943. # This is used to simplify the indexing calculations.
  944. assert in_typed.shape == expected.shape
  945. # We'll do the calculation in four chunks. M1 and N1 are the dimensions
  946. # of the first output chunk. We have to extend the input by half the
  947. # kernel size to be able to calculate the full output chunk.
  948. M1 = expected.shape[0] // 2
  949. N1 = expected.shape[1] // 2
  950. offM = self.KERNEL_SIZE[0] // 2 + 1
  951. offN = self.KERNEL_SIZE[1] // 2 + 1
  952. def apply(chunk):
  953. # in = slice of in_typed to use.
  954. # sel = slice of output to crop it to the correct region.
  955. # out = slice of output array to store in.
  956. M, N = chunk
  957. if M == 0:
  958. Min = slice(0, M1 + offM)
  959. Msel = slice(0, -offM)
  960. Mout = slice(0, M1)
  961. else:
  962. Min = slice(M1 - offM, None)
  963. Msel = slice(offM, None)
  964. Mout = slice(M1, None)
  965. if N == 0:
  966. Nin = slice(0, N1 + offN)
  967. Nsel = slice(0, -offN)
  968. Nout = slice(0, N1)
  969. else:
  970. Nin = slice(N1 - offN, None)
  971. Nsel = slice(offN, None)
  972. Nout = slice(N1, None)
  973. # Do the calculation, but do not write to the output in the threads.
  974. chunk_data = in_typed[Min, Nin]
  975. med = signal.medfilt2d(chunk_data, self.KERNEL_SIZE)
  976. return med[Msel, Nsel], Mout, Nout
  977. # Give each chunk to a different thread.
  978. output = np.zeros_like(expected)
  979. with ThreadPoolExecutor(max_workers=4) as pool:
  980. chunks = {(0, 0), (0, 1), (1, 0), (1, 1)}
  981. futures = {pool.submit(apply, chunk) for chunk in chunks}
  982. # Store each result in the output as it arrives.
  983. for future in as_completed(futures):
  984. data, Mslice, Nslice = future.result()
  985. output[Mslice, Nslice] = data
  986. assert_array_equal(output, expected)
  987. class TestWiener:
  988. def test_basic(self):
  989. g = array([[5, 6, 4, 3],
  990. [3, 5, 6, 2],
  991. [2, 3, 5, 6],
  992. [1, 6, 9, 7]], 'd')
  993. h = array([[2.16374269, 3.2222222222, 2.8888888889, 1.6666666667],
  994. [2.666666667, 4.33333333333, 4.44444444444, 2.8888888888],
  995. [2.222222222, 4.4444444444, 5.4444444444, 4.801066874837],
  996. [1.33333333333, 3.92735042735, 6.0712560386, 5.0404040404]])
  997. assert_array_almost_equal(signal.wiener(g), h, decimal=6)
  998. assert_array_almost_equal(signal.wiener(g, mysize=3), h, decimal=6)
  999. padtype_options = ["mean", "median", "minimum", "maximum", "line"]
  1000. padtype_options += _upfirdn_modes
  1001. class TestResample:
  1002. def test_basic(self):
  1003. # Some basic tests
  1004. # Regression test for issue #3603.
  1005. # window.shape must equal to sig.shape[0]
  1006. sig = np.arange(128)
  1007. num = 256
  1008. win = signal.get_window(('kaiser', 8.0), 160)
  1009. assert_raises(ValueError, signal.resample, sig, num, window=win)
  1010. # Other degenerate conditions
  1011. assert_raises(ValueError, signal.resample_poly, sig, 'yo', 1)
  1012. assert_raises(ValueError, signal.resample_poly, sig, 1, 0)
  1013. assert_raises(ValueError, signal.resample_poly, sig, 2, 1, padtype='')
  1014. assert_raises(ValueError, signal.resample_poly, sig, 2, 1,
  1015. padtype='mean', cval=10)
  1016. # test for issue #6505 - should not modify window.shape when axis ≠ 0
  1017. sig2 = np.tile(np.arange(160), (2, 1))
  1018. signal.resample(sig2, num, axis=-1, window=win)
  1019. assert_(win.shape == (160,))
  1020. @pytest.mark.parametrize('window', (None, 'hamming'))
  1021. @pytest.mark.parametrize('N', (20, 19))
  1022. @pytest.mark.parametrize('num', (100, 101, 10, 11))
  1023. def test_rfft(self, N, num, window):
  1024. # Make sure the speed up using rfft gives the same result as the normal
  1025. # way using fft
  1026. x = np.linspace(0, 10, N, endpoint=False)
  1027. y = np.cos(-x**2/6.0)
  1028. assert_allclose(signal.resample(y, num, window=window),
  1029. signal.resample(y + 0j, num, window=window).real)
  1030. y = np.array([np.cos(-x**2/6.0), np.sin(-x**2/6.0)])
  1031. y_complex = y + 0j
  1032. assert_allclose(
  1033. signal.resample(y, num, axis=1, window=window),
  1034. signal.resample(y_complex, num, axis=1, window=window).real,
  1035. atol=1e-9)
  1036. def test_input_domain(self):
  1037. # Test if both input domain modes produce the same results.
  1038. tsig = np.arange(256) + 0j
  1039. fsig = fft(tsig)
  1040. num = 256
  1041. assert_allclose(
  1042. signal.resample(fsig, num, domain='freq'),
  1043. signal.resample(tsig, num, domain='time'),
  1044. atol=1e-9)
  1045. @pytest.mark.parametrize('nx', (1, 2, 3, 5, 8))
  1046. @pytest.mark.parametrize('ny', (1, 2, 3, 5, 8))
  1047. @pytest.mark.parametrize('dtype', ('float', 'complex'))
  1048. def test_dc(self, nx, ny, dtype):
  1049. x = np.array([1] * nx, dtype)
  1050. y = signal.resample(x, ny)
  1051. assert_allclose(y, [1] * ny)
  1052. @pytest.mark.parametrize('padtype', padtype_options)
  1053. def test_mutable_window(self, padtype):
  1054. # Test that a mutable window is not modified
  1055. impulse = np.zeros(3)
  1056. window = np.random.RandomState(0).randn(2)
  1057. window_orig = window.copy()
  1058. signal.resample_poly(impulse, 5, 1, window=window, padtype=padtype)
  1059. assert_array_equal(window, window_orig)
  1060. @pytest.mark.parametrize('padtype', padtype_options)
  1061. def test_output_float32(self, padtype):
  1062. # Test that float32 inputs yield a float32 output
  1063. x = np.arange(10, dtype=np.float32)
  1064. h = np.array([1, 1, 1], dtype=np.float32)
  1065. y = signal.resample_poly(x, 1, 2, window=h, padtype=padtype)
  1066. assert y.dtype == np.float32
  1067. @pytest.mark.parametrize('padtype', padtype_options)
  1068. @pytest.mark.parametrize('dtype', [np.float32, np.float64])
  1069. def test_output_match_dtype(self, padtype, dtype):
  1070. # Test that the dtype of x is preserved per issue #14733
  1071. x = np.arange(10, dtype=dtype)
  1072. y = signal.resample_poly(x, 1, 2, padtype=padtype)
  1073. assert y.dtype == x.dtype
  1074. @pytest.mark.parametrize(
  1075. "method, ext, padtype",
  1076. [("fft", False, None)]
  1077. + list(
  1078. product(
  1079. ["polyphase"], [False, True], padtype_options,
  1080. )
  1081. ),
  1082. )
  1083. def test_resample_methods(self, method, ext, padtype):
  1084. # Test resampling of sinusoids and random noise (1-sec)
  1085. rate = 100
  1086. rates_to = [49, 50, 51, 99, 100, 101, 199, 200, 201]
  1087. # Sinusoids, windowed to avoid edge artifacts
  1088. t = np.arange(rate) / float(rate)
  1089. freqs = np.array((1., 10., 40.))[:, np.newaxis]
  1090. x = np.sin(2 * np.pi * freqs * t) * hann(rate)
  1091. for rate_to in rates_to:
  1092. t_to = np.arange(rate_to) / float(rate_to)
  1093. y_tos = np.sin(2 * np.pi * freqs * t_to) * hann(rate_to)
  1094. if method == 'fft':
  1095. y_resamps = signal.resample(x, rate_to, axis=-1)
  1096. else:
  1097. if ext and rate_to != rate:
  1098. # Match default window design
  1099. g = gcd(rate_to, rate)
  1100. up = rate_to // g
  1101. down = rate // g
  1102. max_rate = max(up, down)
  1103. f_c = 1. / max_rate
  1104. half_len = 10 * max_rate
  1105. window = signal.firwin(2 * half_len + 1, f_c,
  1106. window=('kaiser', 5.0))
  1107. polyargs = {'window': window, 'padtype': padtype}
  1108. else:
  1109. polyargs = {'padtype': padtype}
  1110. y_resamps = signal.resample_poly(x, rate_to, rate, axis=-1,
  1111. **polyargs)
  1112. for y_to, y_resamp, freq in zip(y_tos, y_resamps, freqs):
  1113. if freq >= 0.5 * rate_to:
  1114. y_to.fill(0.) # mostly low-passed away
  1115. if padtype in ['minimum', 'maximum']:
  1116. assert_allclose(y_resamp, y_to, atol=3e-1)
  1117. else:
  1118. assert_allclose(y_resamp, y_to, atol=1e-3)
  1119. else:
  1120. assert_array_equal(y_to.shape, y_resamp.shape)
  1121. corr = np.corrcoef(y_to, y_resamp)[0, 1]
  1122. assert_(corr > 0.99, msg=(corr, rate, rate_to))
  1123. # Random data
  1124. rng = np.random.RandomState(0)
  1125. x = hann(rate) * np.cumsum(rng.randn(rate)) # low-pass, wind
  1126. for rate_to in rates_to:
  1127. # random data
  1128. t_to = np.arange(rate_to) / float(rate_to)
  1129. y_to = np.interp(t_to, t, x)
  1130. if method == 'fft':
  1131. y_resamp = signal.resample(x, rate_to)
  1132. else:
  1133. y_resamp = signal.resample_poly(x, rate_to, rate,
  1134. padtype=padtype)
  1135. assert_array_equal(y_to.shape, y_resamp.shape)
  1136. corr = np.corrcoef(y_to, y_resamp)[0, 1]
  1137. assert_(corr > 0.99, msg=corr)
  1138. # More tests of fft method (Master 0.18.1 fails these)
  1139. if method == 'fft':
  1140. x1 = np.array([1.+0.j, 0.+0.j])
  1141. y1_test = signal.resample(x1, 4)
  1142. # upsampling a complex array
  1143. y1_true = np.array([1.+0.j, 0.5+0.j, 0.+0.j, 0.5+0.j])
  1144. assert_allclose(y1_test, y1_true, atol=1e-12)
  1145. x2 = np.array([1., 0.5, 0., 0.5])
  1146. y2_test = signal.resample(x2, 2) # downsampling a real array
  1147. y2_true = np.array([1., 0.])
  1148. assert_allclose(y2_test, y2_true, atol=1e-12)
  1149. def test_poly_vs_filtfilt(self):
  1150. # Check that up=1.0 gives same answer as filtfilt + slicing
  1151. random_state = np.random.RandomState(17)
  1152. try_types = (int, np.float32, np.complex64, float, complex)
  1153. size = 10000
  1154. down_factors = [2, 11, 79]
  1155. for dtype in try_types:
  1156. x = random_state.randn(size).astype(dtype)
  1157. if dtype in (np.complex64, np.complex128):
  1158. x += 1j * random_state.randn(size)
  1159. # resample_poly assumes zeros outside of signl, whereas filtfilt
  1160. # can only constant-pad. Make them equivalent:
  1161. x[0] = 0
  1162. x[-1] = 0
  1163. for down in down_factors:
  1164. h = signal.firwin(31, 1. / down, window='hamming')
  1165. yf = filtfilt(h, 1.0, x, padtype='constant')[::down]
  1166. # Need to pass convolved version of filter to resample_poly,
  1167. # since filtfilt does forward and backward, but resample_poly
  1168. # only goes forward
  1169. hc = convolve(h, h[::-1])
  1170. y = signal.resample_poly(x, 1, down, window=hc)
  1171. assert_allclose(yf, y, atol=1e-7, rtol=1e-7)
  1172. def test_correlate1d(self):
  1173. for down in [2, 4]:
  1174. for nx in range(1, 40, down):
  1175. for nweights in (32, 33):
  1176. x = np.random.random((nx,))
  1177. weights = np.random.random((nweights,))
  1178. y_g = correlate1d(x, weights[::-1], mode='constant')
  1179. y_s = signal.resample_poly(
  1180. x, up=1, down=down, window=weights)
  1181. assert_allclose(y_g[::down], y_s)
  1182. class TestCSpline1DEval:
  1183. def test_basic(self):
  1184. y = array([1, 2, 3, 4, 3, 2, 1, 2, 3.0])
  1185. x = arange(len(y))
  1186. dx = x[1] - x[0]
  1187. cj = signal.cspline1d(y)
  1188. x2 = arange(len(y) * 10.0) / 10.0
  1189. y2 = signal.cspline1d_eval(cj, x2, dx=dx, x0=x[0])
  1190. # make sure interpolated values are on knot points
  1191. assert_array_almost_equal(y2[::10], y, decimal=5)
  1192. def test_complex(self):
  1193. # create some smoothly varying complex signal to interpolate
  1194. x = np.arange(2)
  1195. y = np.zeros(x.shape, dtype=np.complex64)
  1196. T = 10.0
  1197. f = 1.0 / T
  1198. y = np.exp(2.0J * np.pi * f * x)
  1199. # get the cspline transform
  1200. cy = signal.cspline1d(y)
  1201. # determine new test x value and interpolate
  1202. xnew = np.array([0.5])
  1203. ynew = signal.cspline1d_eval(cy, xnew)
  1204. assert_equal(ynew.dtype, y.dtype)
  1205. class TestOrderFilt:
  1206. def test_basic(self):
  1207. assert_array_equal(signal.order_filter([1, 2, 3], [1, 0, 1], 1),
  1208. [2, 3, 2])
  1209. class _TestLinearFilter:
  1210. def generate(self, shape):
  1211. x = np.linspace(0, np.prod(shape) - 1, np.prod(shape)).reshape(shape)
  1212. return self.convert_dtype(x)
  1213. def convert_dtype(self, arr):
  1214. if self.dtype == np.dtype('O'):
  1215. arr = np.asarray(arr)
  1216. out = np.empty(arr.shape, self.dtype)
  1217. iter = np.nditer([arr, out], ['refs_ok','zerosize_ok'],
  1218. [['readonly'],['writeonly']])
  1219. for x, y in iter:
  1220. y[...] = self.type(x[()])
  1221. return out
  1222. else:
  1223. return np.array(arr, self.dtype, copy=False)
  1224. def test_rank_1_IIR(self):
  1225. x = self.generate((6,))
  1226. b = self.convert_dtype([1, -1])
  1227. a = self.convert_dtype([0.5, -0.5])
  1228. y_r = self.convert_dtype([0, 2, 4, 6, 8, 10.])
  1229. assert_array_almost_equal(lfilter(b, a, x), y_r)
  1230. def test_rank_1_FIR(self):
  1231. x = self.generate((6,))
  1232. b = self.convert_dtype([1, 1])
  1233. a = self.convert_dtype([1])
  1234. y_r = self.convert_dtype([0, 1, 3, 5, 7, 9.])
  1235. assert_array_almost_equal(lfilter(b, a, x), y_r)
  1236. def test_rank_1_IIR_init_cond(self):
  1237. x = self.generate((6,))
  1238. b = self.convert_dtype([1, 0, -1])
  1239. a = self.convert_dtype([0.5, -0.5])
  1240. zi = self.convert_dtype([1, 2])
  1241. y_r = self.convert_dtype([1, 5, 9, 13, 17, 21])
  1242. zf_r = self.convert_dtype([13, -10])
  1243. y, zf = lfilter(b, a, x, zi=zi)
  1244. assert_array_almost_equal(y, y_r)
  1245. assert_array_almost_equal(zf, zf_r)
  1246. def test_rank_1_FIR_init_cond(self):
  1247. x = self.generate((6,))
  1248. b = self.convert_dtype([1, 1, 1])
  1249. a = self.convert_dtype([1])
  1250. zi = self.convert_dtype([1, 1])
  1251. y_r = self.convert_dtype([1, 2, 3, 6, 9, 12.])
  1252. zf_r = self.convert_dtype([9, 5])
  1253. y, zf = lfilter(b, a, x, zi=zi)
  1254. assert_array_almost_equal(y, y_r)
  1255. assert_array_almost_equal(zf, zf_r)
  1256. def test_rank_2_IIR_axis_0(self):
  1257. x = self.generate((4, 3))
  1258. b = self.convert_dtype([1, -1])
  1259. a = self.convert_dtype([0.5, 0.5])
  1260. y_r2_a0 = self.convert_dtype([[0, 2, 4], [6, 4, 2], [0, 2, 4],
  1261. [6, 4, 2]])
  1262. y = lfilter(b, a, x, axis=0)
  1263. assert_array_almost_equal(y_r2_a0, y)
  1264. def test_rank_2_IIR_axis_1(self):
  1265. x = self.generate((4, 3))
  1266. b = self.convert_dtype([1, -1])
  1267. a = self.convert_dtype([0.5, 0.5])
  1268. y_r2_a1 = self.convert_dtype([[0, 2, 0], [6, -4, 6], [12, -10, 12],
  1269. [18, -16, 18]])
  1270. y = lfilter(b, a, x, axis=1)
  1271. assert_array_almost_equal(y_r2_a1, y)
  1272. def test_rank_2_IIR_axis_0_init_cond(self):
  1273. x = self.generate((4, 3))
  1274. b = self.convert_dtype([1, -1])
  1275. a = self.convert_dtype([0.5, 0.5])
  1276. zi = self.convert_dtype(np.ones((4,1)))
  1277. y_r2_a0_1 = self.convert_dtype([[1, 1, 1], [7, -5, 7], [13, -11, 13],
  1278. [19, -17, 19]])
  1279. zf_r = self.convert_dtype([-5, -17, -29, -41])[:, np.newaxis]
  1280. y, zf = lfilter(b, a, x, axis=1, zi=zi)
  1281. assert_array_almost_equal(y_r2_a0_1, y)
  1282. assert_array_almost_equal(zf, zf_r)
  1283. def test_rank_2_IIR_axis_1_init_cond(self):
  1284. x = self.generate((4,3))
  1285. b = self.convert_dtype([1, -1])
  1286. a = self.convert_dtype([0.5, 0.5])
  1287. zi = self.convert_dtype(np.ones((1,3)))
  1288. y_r2_a0_0 = self.convert_dtype([[1, 3, 5], [5, 3, 1],
  1289. [1, 3, 5], [5, 3, 1]])
  1290. zf_r = self.convert_dtype([[-23, -23, -23]])
  1291. y, zf = lfilter(b, a, x, axis=0, zi=zi)
  1292. assert_array_almost_equal(y_r2_a0_0, y)
  1293. assert_array_almost_equal(zf, zf_r)
  1294. def test_rank_3_IIR(self):
  1295. x = self.generate((4, 3, 2))
  1296. b = self.convert_dtype([1, -1])
  1297. a = self.convert_dtype([0.5, 0.5])
  1298. for axis in range(x.ndim):
  1299. y = lfilter(b, a, x, axis)
  1300. y_r = np.apply_along_axis(lambda w: lfilter(b, a, w), axis, x)
  1301. assert_array_almost_equal(y, y_r)
  1302. def test_rank_3_IIR_init_cond(self):
  1303. x = self.generate((4, 3, 2))
  1304. b = self.convert_dtype([1, -1])
  1305. a = self.convert_dtype([0.5, 0.5])
  1306. for axis in range(x.ndim):
  1307. zi_shape = list(x.shape)
  1308. zi_shape[axis] = 1
  1309. zi = self.convert_dtype(np.ones(zi_shape))
  1310. zi1 = self.convert_dtype([1])
  1311. y, zf = lfilter(b, a, x, axis, zi)
  1312. lf0 = lambda w: lfilter(b, a, w, zi=zi1)[0]
  1313. lf1 = lambda w: lfilter(b, a, w, zi=zi1)[1]
  1314. y_r = np.apply_along_axis(lf0, axis, x)
  1315. zf_r = np.apply_along_axis(lf1, axis, x)
  1316. assert_array_almost_equal(y, y_r)
  1317. assert_array_almost_equal(zf, zf_r)
  1318. def test_rank_3_FIR(self):
  1319. x = self.generate((4, 3, 2))
  1320. b = self.convert_dtype([1, 0, -1])
  1321. a = self.convert_dtype([1])
  1322. for axis in range(x.ndim):
  1323. y = lfilter(b, a, x, axis)
  1324. y_r = np.apply_along_axis(lambda w: lfilter(b, a, w), axis, x)
  1325. assert_array_almost_equal(y, y_r)
  1326. def test_rank_3_FIR_init_cond(self):
  1327. x = self.generate((4, 3, 2))
  1328. b = self.convert_dtype([1, 0, -1])
  1329. a = self.convert_dtype([1])
  1330. for axis in range(x.ndim):
  1331. zi_shape = list(x.shape)
  1332. zi_shape[axis] = 2
  1333. zi = self.convert_dtype(np.ones(zi_shape))
  1334. zi1 = self.convert_dtype([1, 1])
  1335. y, zf = lfilter(b, a, x, axis, zi)
  1336. lf0 = lambda w: lfilter(b, a, w, zi=zi1)[0]
  1337. lf1 = lambda w: lfilter(b, a, w, zi=zi1)[1]
  1338. y_r = np.apply_along_axis(lf0, axis, x)
  1339. zf_r = np.apply_along_axis(lf1, axis, x)
  1340. assert_array_almost_equal(y, y_r)
  1341. assert_array_almost_equal(zf, zf_r)
  1342. def test_zi_pseudobroadcast(self):
  1343. x = self.generate((4, 5, 20))
  1344. b,a = signal.butter(8, 0.2, output='ba')
  1345. b = self.convert_dtype(b)
  1346. a = self.convert_dtype(a)
  1347. zi_size = b.shape[0] - 1
  1348. # lfilter requires x.ndim == zi.ndim exactly. However, zi can have
  1349. # length 1 dimensions.
  1350. zi_full = self.convert_dtype(np.ones((4, 5, zi_size)))
  1351. zi_sing = self.convert_dtype(np.ones((1, 1, zi_size)))
  1352. y_full, zf_full = lfilter(b, a, x, zi=zi_full)
  1353. y_sing, zf_sing = lfilter(b, a, x, zi=zi_sing)
  1354. assert_array_almost_equal(y_sing, y_full)
  1355. assert_array_almost_equal(zf_full, zf_sing)
  1356. # lfilter does not prepend ones
  1357. assert_raises(ValueError, lfilter, b, a, x, -1, np.ones(zi_size))
  1358. def test_scalar_a(self):
  1359. # a can be a scalar.
  1360. x = self.generate(6)
  1361. b = self.convert_dtype([1, 0, -1])
  1362. a = self.convert_dtype([1])
  1363. y_r = self.convert_dtype([0, 1, 2, 2, 2, 2])
  1364. y = lfilter(b, a[0], x)
  1365. assert_array_almost_equal(y, y_r)
  1366. def test_zi_some_singleton_dims(self):
  1367. # lfilter doesn't really broadcast (no prepending of 1's). But does
  1368. # do singleton expansion if x and zi have the same ndim. This was
  1369. # broken only if a subset of the axes were singletons (gh-4681).
  1370. x = self.convert_dtype(np.zeros((3,2,5), 'l'))
  1371. b = self.convert_dtype(np.ones(5, 'l'))
  1372. a = self.convert_dtype(np.array([1,0,0]))
  1373. zi = np.ones((3,1,4), 'l')
  1374. zi[1,:,:] *= 2
  1375. zi[2,:,:] *= 3
  1376. zi = self.convert_dtype(zi)
  1377. zf_expected = self.convert_dtype(np.zeros((3,2,4), 'l'))
  1378. y_expected = np.zeros((3,2,5), 'l')
  1379. y_expected[:,:,:4] = [[[1]], [[2]], [[3]]]
  1380. y_expected = self.convert_dtype(y_expected)
  1381. # IIR
  1382. y_iir, zf_iir = lfilter(b, a, x, -1, zi)
  1383. assert_array_almost_equal(y_iir, y_expected)
  1384. assert_array_almost_equal(zf_iir, zf_expected)
  1385. # FIR
  1386. y_fir, zf_fir = lfilter(b, a[0], x, -1, zi)
  1387. assert_array_almost_equal(y_fir, y_expected)
  1388. assert_array_almost_equal(zf_fir, zf_expected)
  1389. def base_bad_size_zi(self, b, a, x, axis, zi):
  1390. b = self.convert_dtype(b)
  1391. a = self.convert_dtype(a)
  1392. x = self.convert_dtype(x)
  1393. zi = self.convert_dtype(zi)
  1394. assert_raises(ValueError, lfilter, b, a, x, axis, zi)
  1395. def test_bad_size_zi(self):
  1396. # rank 1
  1397. x1 = np.arange(6)
  1398. self.base_bad_size_zi([1], [1], x1, -1, [1])
  1399. self.base_bad_size_zi([1, 1], [1], x1, -1, [0, 1])
  1400. self.base_bad_size_zi([1, 1], [1], x1, -1, [[0]])
  1401. self.base_bad_size_zi([1, 1], [1], x1, -1, [0, 1, 2])
  1402. self.base_bad_size_zi([1, 1, 1], [1], x1, -1, [[0]])
  1403. self.base_bad_size_zi([1, 1, 1], [1], x1, -1, [0, 1, 2])
  1404. self.base_bad_size_zi([1], [1, 1], x1, -1, [0, 1])
  1405. self.base_bad_size_zi([1], [1, 1], x1, -1, [[0]])
  1406. self.base_bad_size_zi([1], [1, 1], x1, -1, [0, 1, 2])
  1407. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0])
  1408. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [[0], [1]])
  1409. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0, 1, 2])
  1410. self.base_bad_size_zi([1, 1, 1], [1, 1], x1, -1, [0, 1, 2, 3])
  1411. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0])
  1412. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [[0], [1]])
  1413. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0, 1, 2])
  1414. self.base_bad_size_zi([1, 1], [1, 1, 1], x1, -1, [0, 1, 2, 3])
  1415. # rank 2
  1416. x2 = np.arange(12).reshape((4,3))
  1417. # for axis=0 zi.shape should == (max(len(a),len(b))-1, 3)
  1418. self.base_bad_size_zi([1], [1], x2, 0, [0])
  1419. # for each of these there are 5 cases tested (in this order):
  1420. # 1. not deep enough, right # elements
  1421. # 2. too deep, right # elements
  1422. # 3. right depth, right # elements, transposed
  1423. # 4. right depth, too few elements
  1424. # 5. right depth, too many elements
  1425. self.base_bad_size_zi([1, 1], [1], x2, 0, [0,1,2])
  1426. self.base_bad_size_zi([1, 1], [1], x2, 0, [[[0,1,2]]])
  1427. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0], [1], [2]])
  1428. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0,1]])
  1429. self.base_bad_size_zi([1, 1], [1], x2, 0, [[0,1,2,3]])
  1430. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [0,1,2,3,4,5])
  1431. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[[0,1,2],[3,4,5]]])
  1432. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0,1],[2,3],[4,5]])
  1433. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0,1],[2,3]])
  1434. self.base_bad_size_zi([1, 1, 1], [1], x2, 0, [[0,1,2,3],[4,5,6,7]])
  1435. self.base_bad_size_zi([1], [1, 1], x2, 0, [0,1,2])
  1436. self.base_bad_size_zi([1], [1, 1], x2, 0, [[[0,1,2]]])
  1437. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0], [1], [2]])
  1438. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0,1]])
  1439. self.base_bad_size_zi([1], [1, 1], x2, 0, [[0,1,2,3]])
  1440. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [0,1,2,3,4,5])
  1441. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[[0,1,2],[3,4,5]]])
  1442. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0,1],[2,3],[4,5]])
  1443. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0,1],[2,3]])
  1444. self.base_bad_size_zi([1], [1, 1, 1], x2, 0, [[0,1,2,3],[4,5,6,7]])
  1445. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [0,1,2,3,4,5])
  1446. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[[0,1,2],[3,4,5]]])
  1447. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0,1],[2,3],[4,5]])
  1448. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0,1],[2,3]])
  1449. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 0, [[0,1,2,3],[4,5,6,7]])
  1450. # for axis=1 zi.shape should == (4, max(len(a),len(b))-1)
  1451. self.base_bad_size_zi([1], [1], x2, 1, [0])
  1452. self.base_bad_size_zi([1, 1], [1], x2, 1, [0,1,2,3])
  1453. self.base_bad_size_zi([1, 1], [1], x2, 1, [[[0],[1],[2],[3]]])
  1454. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0, 1, 2, 3]])
  1455. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0],[1],[2]])
  1456. self.base_bad_size_zi([1, 1], [1], x2, 1, [[0],[1],[2],[3],[4]])
  1457. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [0,1,2,3,4,5,6,7])
  1458. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[[0,1],[2,3],[4,5],[6,7]]])
  1459. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0,1,2,3],[4,5,6,7]])
  1460. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0,1],[2,3],[4,5]])
  1461. self.base_bad_size_zi([1, 1, 1], [1], x2, 1, [[0,1],[2,3],[4,5],[6,7],[8,9]])
  1462. self.base_bad_size_zi([1], [1, 1], x2, 1, [0,1,2,3])
  1463. self.base_bad_size_zi([1], [1, 1], x2, 1, [[[0],[1],[2],[3]]])
  1464. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0, 1, 2, 3]])
  1465. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0],[1],[2]])
  1466. self.base_bad_size_zi([1], [1, 1], x2, 1, [[0],[1],[2],[3],[4]])
  1467. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [0,1,2,3,4,5,6,7])
  1468. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[[0,1],[2,3],[4,5],[6,7]]])
  1469. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0,1,2,3],[4,5,6,7]])
  1470. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0,1],[2,3],[4,5]])
  1471. self.base_bad_size_zi([1], [1, 1, 1], x2, 1, [[0,1],[2,3],[4,5],[6,7],[8,9]])
  1472. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [0,1,2,3,4,5,6,7])
  1473. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[[0,1],[2,3],[4,5],[6,7]]])
  1474. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0,1,2,3],[4,5,6,7]])
  1475. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0,1],[2,3],[4,5]])
  1476. self.base_bad_size_zi([1, 1, 1], [1, 1], x2, 1, [[0,1],[2,3],[4,5],[6,7],[8,9]])
  1477. def test_empty_zi(self):
  1478. # Regression test for #880: empty array for zi crashes.
  1479. x = self.generate((5,))
  1480. a = self.convert_dtype([1])
  1481. b = self.convert_dtype([1])
  1482. zi = self.convert_dtype([])
  1483. y, zf = lfilter(b, a, x, zi=zi)
  1484. assert_array_almost_equal(y, x)
  1485. assert_equal(zf.dtype, self.dtype)
  1486. assert_equal(zf.size, 0)
  1487. def test_lfiltic_bad_zi(self):
  1488. # Regression test for #3699: bad initial conditions
  1489. a = self.convert_dtype([1])
  1490. b = self.convert_dtype([1])
  1491. # "y" sets the datatype of zi, so it truncates if int
  1492. zi = lfiltic(b, a, [1., 0])
  1493. zi_1 = lfiltic(b, a, [1, 0])
  1494. zi_2 = lfiltic(b, a, [True, False])
  1495. assert_array_equal(zi, zi_1)
  1496. assert_array_equal(zi, zi_2)
  1497. def test_short_x_FIR(self):
  1498. # regression test for #5116
  1499. # x shorter than b, with non None zi fails
  1500. a = self.convert_dtype([1])
  1501. b = self.convert_dtype([1, 0, -1])
  1502. zi = self.convert_dtype([2, 7])
  1503. x = self.convert_dtype([72])
  1504. ye = self.convert_dtype([74])
  1505. zfe = self.convert_dtype([7, -72])
  1506. y, zf = lfilter(b, a, x, zi=zi)
  1507. assert_array_almost_equal(y, ye)
  1508. assert_array_almost_equal(zf, zfe)
  1509. def test_short_x_IIR(self):
  1510. # regression test for #5116
  1511. # x shorter than b, with non None zi fails
  1512. a = self.convert_dtype([1, 1])
  1513. b = self.convert_dtype([1, 0, -1])
  1514. zi = self.convert_dtype([2, 7])
  1515. x = self.convert_dtype([72])
  1516. ye = self.convert_dtype([74])
  1517. zfe = self.convert_dtype([-67, -72])
  1518. y, zf = lfilter(b, a, x, zi=zi)
  1519. assert_array_almost_equal(y, ye)
  1520. assert_array_almost_equal(zf, zfe)
  1521. def test_do_not_modify_a_b_IIR(self):
  1522. x = self.generate((6,))
  1523. b = self.convert_dtype([1, -1])
  1524. b0 = b.copy()
  1525. a = self.convert_dtype([0.5, -0.5])
  1526. a0 = a.copy()
  1527. y_r = self.convert_dtype([0, 2, 4, 6, 8, 10.])
  1528. y_f = lfilter(b, a, x)
  1529. assert_array_almost_equal(y_f, y_r)
  1530. assert_equal(b, b0)
  1531. assert_equal(a, a0)
  1532. def test_do_not_modify_a_b_FIR(self):
  1533. x = self.generate((6,))
  1534. b = self.convert_dtype([1, 0, 1])
  1535. b0 = b.copy()
  1536. a = self.convert_dtype([2])
  1537. a0 = a.copy()
  1538. y_r = self.convert_dtype([0, 0.5, 1, 2, 3, 4.])
  1539. y_f = lfilter(b, a, x)
  1540. assert_array_almost_equal(y_f, y_r)
  1541. assert_equal(b, b0)
  1542. assert_equal(a, a0)
  1543. class TestLinearFilterFloat32(_TestLinearFilter):
  1544. dtype = np.dtype('f')
  1545. class TestLinearFilterFloat64(_TestLinearFilter):
  1546. dtype = np.dtype('d')
  1547. class TestLinearFilterFloatExtended(_TestLinearFilter):
  1548. dtype = np.dtype('g')
  1549. class TestLinearFilterComplex64(_TestLinearFilter):
  1550. dtype = np.dtype('F')
  1551. class TestLinearFilterComplex128(_TestLinearFilter):
  1552. dtype = np.dtype('D')
  1553. class TestLinearFilterComplexExtended(_TestLinearFilter):
  1554. dtype = np.dtype('G')
  1555. class TestLinearFilterDecimal(_TestLinearFilter):
  1556. dtype = np.dtype('O')
  1557. def type(self, x):
  1558. return Decimal(str(x))
  1559. class TestLinearFilterObject(_TestLinearFilter):
  1560. dtype = np.dtype('O')
  1561. type = float
  1562. def test_lfilter_bad_object():
  1563. # lfilter: object arrays with non-numeric objects raise TypeError.
  1564. # Regression test for ticket #1452.
  1565. assert_raises(TypeError, lfilter, [1.0], [1.0], [1.0, None, 2.0])
  1566. assert_raises(TypeError, lfilter, [1.0], [None], [1.0, 2.0, 3.0])
  1567. assert_raises(TypeError, lfilter, [None], [1.0], [1.0, 2.0, 3.0])
  1568. def test_lfilter_notimplemented_input():
  1569. # Should not crash, gh-7991
  1570. assert_raises(NotImplementedError, lfilter, [2,3], [4,5], [1,2,3,4,5])
  1571. @pytest.mark.parametrize('dt', [np.ubyte, np.byte, np.ushort, np.short,
  1572. np.uint, int, np.ulonglong, np.ulonglong,
  1573. np.float32, np.float64, np.longdouble,
  1574. Decimal])
  1575. class TestCorrelateReal:
  1576. def _setup_rank1(self, dt):
  1577. a = np.linspace(0, 3, 4).astype(dt)
  1578. b = np.linspace(1, 2, 2).astype(dt)
  1579. y_r = np.array([0, 2, 5, 8, 3]).astype(dt)
  1580. return a, b, y_r
  1581. def equal_tolerance(self, res_dt):
  1582. # default value of keyword
  1583. decimal = 6
  1584. try:
  1585. dt_info = np.finfo(res_dt)
  1586. if hasattr(dt_info, 'resolution'):
  1587. decimal = int(-0.5*np.log10(dt_info.resolution))
  1588. except Exception:
  1589. pass
  1590. return decimal
  1591. def equal_tolerance_fft(self, res_dt):
  1592. # FFT implementations convert longdouble arguments down to
  1593. # double so don't expect better precision, see gh-9520
  1594. if res_dt == np.longdouble:
  1595. return self.equal_tolerance(np.double)
  1596. else:
  1597. return self.equal_tolerance(res_dt)
  1598. def test_method(self, dt):
  1599. if dt == Decimal:
  1600. method = choose_conv_method([Decimal(4)], [Decimal(3)])
  1601. assert_equal(method, 'direct')
  1602. else:
  1603. a, b, y_r = self._setup_rank3(dt)
  1604. y_fft = correlate(a, b, method='fft')
  1605. y_direct = correlate(a, b, method='direct')
  1606. assert_array_almost_equal(y_r, y_fft, decimal=self.equal_tolerance_fft(y_fft.dtype))
  1607. assert_array_almost_equal(y_r, y_direct, decimal=self.equal_tolerance(y_direct.dtype))
  1608. assert_equal(y_fft.dtype, dt)
  1609. assert_equal(y_direct.dtype, dt)
  1610. def test_rank1_valid(self, dt):
  1611. a, b, y_r = self._setup_rank1(dt)
  1612. y = correlate(a, b, 'valid')
  1613. assert_array_almost_equal(y, y_r[1:4])
  1614. assert_equal(y.dtype, dt)
  1615. # See gh-5897
  1616. y = correlate(b, a, 'valid')
  1617. assert_array_almost_equal(y, y_r[1:4][::-1])
  1618. assert_equal(y.dtype, dt)
  1619. def test_rank1_same(self, dt):
  1620. a, b, y_r = self._setup_rank1(dt)
  1621. y = correlate(a, b, 'same')
  1622. assert_array_almost_equal(y, y_r[:-1])
  1623. assert_equal(y.dtype, dt)
  1624. def test_rank1_full(self, dt):
  1625. a, b, y_r = self._setup_rank1(dt)
  1626. y = correlate(a, b, 'full')
  1627. assert_array_almost_equal(y, y_r)
  1628. assert_equal(y.dtype, dt)
  1629. def _setup_rank3(self, dt):
  1630. a = np.linspace(0, 39, 40).reshape((2, 4, 5), order='F').astype(
  1631. dt)
  1632. b = np.linspace(0, 23, 24).reshape((2, 3, 4), order='F').astype(
  1633. dt)
  1634. y_r = array([[[0., 184., 504., 912., 1360., 888., 472., 160.],
  1635. [46., 432., 1062., 1840., 2672., 1698., 864., 266.],
  1636. [134., 736., 1662., 2768., 3920., 2418., 1168., 314.],
  1637. [260., 952., 1932., 3056., 4208., 2580., 1240., 332.],
  1638. [202., 664., 1290., 1984., 2688., 1590., 712., 150.],
  1639. [114., 344., 642., 960., 1280., 726., 296., 38.]],
  1640. [[23., 400., 1035., 1832., 2696., 1737., 904., 293.],
  1641. [134., 920., 2166., 3680., 5280., 3306., 1640., 474.],
  1642. [325., 1544., 3369., 5512., 7720., 4683., 2192., 535.],
  1643. [571., 1964., 3891., 6064., 8272., 4989., 2324., 565.],
  1644. [434., 1360., 2586., 3920., 5264., 3054., 1312., 230.],
  1645. [241., 700., 1281., 1888., 2496., 1383., 532., 39.]],
  1646. [[22., 214., 528., 916., 1332., 846., 430., 132.],
  1647. [86., 484., 1098., 1832., 2600., 1602., 772., 206.],
  1648. [188., 802., 1698., 2732., 3788., 2256., 1018., 218.],
  1649. [308., 1006., 1950., 2996., 4052., 2400., 1078., 230.],
  1650. [230., 692., 1290., 1928., 2568., 1458., 596., 78.],
  1651. [126., 354., 636., 924., 1212., 654., 234., 0.]]],
  1652. dtype=dt)
  1653. return a, b, y_r
  1654. def test_rank3_valid(self, dt):
  1655. a, b, y_r = self._setup_rank3(dt)
  1656. y = correlate(a, b, "valid")
  1657. assert_array_almost_equal(y, y_r[1:2, 2:4, 3:5])
  1658. assert_equal(y.dtype, dt)
  1659. # See gh-5897
  1660. y = correlate(b, a, "valid")
  1661. assert_array_almost_equal(y, y_r[1:2, 2:4, 3:5][::-1, ::-1, ::-1])
  1662. assert_equal(y.dtype, dt)
  1663. def test_rank3_same(self, dt):
  1664. a, b, y_r = self._setup_rank3(dt)
  1665. y = correlate(a, b, "same")
  1666. assert_array_almost_equal(y, y_r[0:-1, 1:-1, 1:-2])
  1667. assert_equal(y.dtype, dt)
  1668. def test_rank3_all(self, dt):
  1669. a, b, y_r = self._setup_rank3(dt)
  1670. y = correlate(a, b)
  1671. assert_array_almost_equal(y, y_r)
  1672. assert_equal(y.dtype, dt)
  1673. class TestCorrelate:
  1674. # Tests that don't depend on dtype
  1675. def test_invalid_shapes(self):
  1676. # By "invalid," we mean that no one
  1677. # array has dimensions that are all at
  1678. # least as large as the corresponding
  1679. # dimensions of the other array. This
  1680. # setup should throw a ValueError.
  1681. a = np.arange(1, 7).reshape((2, 3))
  1682. b = np.arange(-6, 0).reshape((3, 2))
  1683. assert_raises(ValueError, correlate, *(a, b), **{'mode': 'valid'})
  1684. assert_raises(ValueError, correlate, *(b, a), **{'mode': 'valid'})
  1685. def test_invalid_params(self):
  1686. a = [3, 4, 5]
  1687. b = [1, 2, 3]
  1688. assert_raises(ValueError, correlate, a, b, mode='spam')
  1689. assert_raises(ValueError, correlate, a, b, mode='eggs', method='fft')
  1690. assert_raises(ValueError, correlate, a, b, mode='ham', method='direct')
  1691. assert_raises(ValueError, correlate, a, b, mode='full', method='bacon')
  1692. assert_raises(ValueError, correlate, a, b, mode='same', method='bacon')
  1693. def test_mismatched_dims(self):
  1694. # Input arrays should have the same number of dimensions
  1695. assert_raises(ValueError, correlate, [1], 2, method='direct')
  1696. assert_raises(ValueError, correlate, 1, [2], method='direct')
  1697. assert_raises(ValueError, correlate, [1], 2, method='fft')
  1698. assert_raises(ValueError, correlate, 1, [2], method='fft')
  1699. assert_raises(ValueError, correlate, [1], [[2]])
  1700. assert_raises(ValueError, correlate, [3], 2)
  1701. def test_numpy_fastpath(self):
  1702. a = [1, 2, 3]
  1703. b = [4, 5]
  1704. assert_allclose(correlate(a, b, mode='same'), [5, 14, 23])
  1705. a = [1, 2, 3]
  1706. b = [4, 5, 6]
  1707. assert_allclose(correlate(a, b, mode='same'), [17, 32, 23])
  1708. assert_allclose(correlate(a, b, mode='full'), [6, 17, 32, 23, 12])
  1709. assert_allclose(correlate(a, b, mode='valid'), [32])
  1710. @pytest.mark.parametrize("mode", ["valid", "same", "full"])
  1711. @pytest.mark.parametrize("behind", [True, False])
  1712. @pytest.mark.parametrize("input_size", [100, 101, 1000, 1001, 10000, 10001])
  1713. def test_correlation_lags(mode, behind, input_size):
  1714. # generate random data
  1715. rng = np.random.RandomState(0)
  1716. in1 = rng.standard_normal(input_size)
  1717. offset = int(input_size/10)
  1718. # generate offset version of array to correlate with
  1719. if behind:
  1720. # y is behind x
  1721. in2 = np.concatenate([rng.standard_normal(offset), in1])
  1722. expected = -offset
  1723. else:
  1724. # y is ahead of x
  1725. in2 = in1[offset:]
  1726. expected = offset
  1727. # cross correlate, returning lag information
  1728. correlation = correlate(in1, in2, mode=mode)
  1729. lags = correlation_lags(in1.size, in2.size, mode=mode)
  1730. # identify the peak
  1731. lag_index = np.argmax(correlation)
  1732. # Check as expected
  1733. assert_equal(lags[lag_index], expected)
  1734. # Correlation and lags shape should match
  1735. assert_equal(lags.shape, correlation.shape)
  1736. @pytest.mark.parametrize('dt', [np.csingle, np.cdouble, np.clongdouble])
  1737. class TestCorrelateComplex:
  1738. # The decimal precision to be used for comparing results.
  1739. # This value will be passed as the 'decimal' keyword argument of
  1740. # assert_array_almost_equal().
  1741. # Since correlate may chose to use FFT method which converts
  1742. # longdoubles to doubles internally don't expect better precision
  1743. # for longdouble than for double (see gh-9520).
  1744. def decimal(self, dt):
  1745. if dt == np.clongdouble:
  1746. dt = np.cdouble
  1747. return int(2 * np.finfo(dt).precision / 3)
  1748. def _setup_rank1(self, dt, mode):
  1749. np.random.seed(9)
  1750. a = np.random.randn(10).astype(dt)
  1751. a += 1j * np.random.randn(10).astype(dt)
  1752. b = np.random.randn(8).astype(dt)
  1753. b += 1j * np.random.randn(8).astype(dt)
  1754. y_r = (correlate(a.real, b.real, mode=mode) +
  1755. correlate(a.imag, b.imag, mode=mode)).astype(dt)
  1756. y_r += 1j * (-correlate(a.real, b.imag, mode=mode) +
  1757. correlate(a.imag, b.real, mode=mode))
  1758. return a, b, y_r
  1759. def test_rank1_valid(self, dt):
  1760. a, b, y_r = self._setup_rank1(dt, 'valid')
  1761. y = correlate(a, b, 'valid')
  1762. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt))
  1763. assert_equal(y.dtype, dt)
  1764. # See gh-5897
  1765. y = correlate(b, a, 'valid')
  1766. assert_array_almost_equal(y, y_r[::-1].conj(), decimal=self.decimal(dt))
  1767. assert_equal(y.dtype, dt)
  1768. def test_rank1_same(self, dt):
  1769. a, b, y_r = self._setup_rank1(dt, 'same')
  1770. y = correlate(a, b, 'same')
  1771. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt))
  1772. assert_equal(y.dtype, dt)
  1773. def test_rank1_full(self, dt):
  1774. a, b, y_r = self._setup_rank1(dt, 'full')
  1775. y = correlate(a, b, 'full')
  1776. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt))
  1777. assert_equal(y.dtype, dt)
  1778. def test_swap_full(self, dt):
  1779. d = np.array([0.+0.j, 1.+1.j, 2.+2.j], dtype=dt)
  1780. k = np.array([1.+3.j, 2.+4.j, 3.+5.j, 4.+6.j], dtype=dt)
  1781. y = correlate(d, k)
  1782. assert_equal(y, [0.+0.j, 10.-2.j, 28.-6.j, 22.-6.j, 16.-6.j, 8.-4.j])
  1783. def test_swap_same(self, dt):
  1784. d = [0.+0.j, 1.+1.j, 2.+2.j]
  1785. k = [1.+3.j, 2.+4.j, 3.+5.j, 4.+6.j]
  1786. y = correlate(d, k, mode="same")
  1787. assert_equal(y, [10.-2.j, 28.-6.j, 22.-6.j])
  1788. def test_rank3(self, dt):
  1789. a = np.random.randn(10, 8, 6).astype(dt)
  1790. a += 1j * np.random.randn(10, 8, 6).astype(dt)
  1791. b = np.random.randn(8, 6, 4).astype(dt)
  1792. b += 1j * np.random.randn(8, 6, 4).astype(dt)
  1793. y_r = (correlate(a.real, b.real)
  1794. + correlate(a.imag, b.imag)).astype(dt)
  1795. y_r += 1j * (-correlate(a.real, b.imag) + correlate(a.imag, b.real))
  1796. y = correlate(a, b, 'full')
  1797. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
  1798. assert_equal(y.dtype, dt)
  1799. def test_rank0(self, dt):
  1800. a = np.array(np.random.randn()).astype(dt)
  1801. a += 1j * np.array(np.random.randn()).astype(dt)
  1802. b = np.array(np.random.randn()).astype(dt)
  1803. b += 1j * np.array(np.random.randn()).astype(dt)
  1804. y_r = (correlate(a.real, b.real)
  1805. + correlate(a.imag, b.imag)).astype(dt)
  1806. y_r += 1j * np.array(-correlate(a.real, b.imag) +
  1807. correlate(a.imag, b.real))
  1808. y = correlate(a, b, 'full')
  1809. assert_array_almost_equal(y, y_r, decimal=self.decimal(dt) - 1)
  1810. assert_equal(y.dtype, dt)
  1811. assert_equal(correlate([1], [2j]), correlate(1, 2j))
  1812. assert_equal(correlate([2j], [3j]), correlate(2j, 3j))
  1813. assert_equal(correlate([3j], [4]), correlate(3j, 4))
  1814. class TestCorrelate2d:
  1815. def test_consistency_correlate_funcs(self):
  1816. # Compare np.correlate, signal.correlate, signal.correlate2d
  1817. a = np.arange(5)
  1818. b = np.array([3.2, 1.4, 3])
  1819. for mode in ['full', 'valid', 'same']:
  1820. assert_almost_equal(np.correlate(a, b, mode=mode),
  1821. signal.correlate(a, b, mode=mode))
  1822. assert_almost_equal(np.squeeze(signal.correlate2d([a], [b],
  1823. mode=mode)),
  1824. signal.correlate(a, b, mode=mode))
  1825. # See gh-5897
  1826. if mode == 'valid':
  1827. assert_almost_equal(np.correlate(b, a, mode=mode),
  1828. signal.correlate(b, a, mode=mode))
  1829. assert_almost_equal(np.squeeze(signal.correlate2d([b], [a],
  1830. mode=mode)),
  1831. signal.correlate(b, a, mode=mode))
  1832. def test_invalid_shapes(self):
  1833. # By "invalid," we mean that no one
  1834. # array has dimensions that are all at
  1835. # least as large as the corresponding
  1836. # dimensions of the other array. This
  1837. # setup should throw a ValueError.
  1838. a = np.arange(1, 7).reshape((2, 3))
  1839. b = np.arange(-6, 0).reshape((3, 2))
  1840. assert_raises(ValueError, signal.correlate2d, *(a, b), **{'mode': 'valid'})
  1841. assert_raises(ValueError, signal.correlate2d, *(b, a), **{'mode': 'valid'})
  1842. def test_complex_input(self):
  1843. assert_equal(signal.correlate2d([[1]], [[2j]]), -2j)
  1844. assert_equal(signal.correlate2d([[2j]], [[3j]]), 6)
  1845. assert_equal(signal.correlate2d([[3j]], [[4]]), 12j)
  1846. class TestLFilterZI:
  1847. def test_basic(self):
  1848. a = np.array([1.0, -1.0, 0.5])
  1849. b = np.array([1.0, 0.0, 2.0])
  1850. zi_expected = np.array([5.0, -1.0])
  1851. zi = lfilter_zi(b, a)
  1852. assert_array_almost_equal(zi, zi_expected)
  1853. def test_scale_invariance(self):
  1854. # Regression test. There was a bug in which b was not correctly
  1855. # rescaled when a[0] was nonzero.
  1856. b = np.array([2, 8, 5])
  1857. a = np.array([1, 1, 8])
  1858. zi1 = lfilter_zi(b, a)
  1859. zi2 = lfilter_zi(2*b, 2*a)
  1860. assert_allclose(zi2, zi1, rtol=1e-12)
  1861. @pytest.mark.parametrize('dtype', [np.float32, np.float64])
  1862. def test_types(self, dtype):
  1863. b = np.zeros((8), dtype=dtype)
  1864. a = np.array([1], dtype=dtype)
  1865. assert_equal(np.real(signal.lfilter_zi(b, a)).dtype, dtype)
  1866. class TestFiltFilt:
  1867. filtfilt_kind = 'tf'
  1868. def filtfilt(self, zpk, x, axis=-1, padtype='odd', padlen=None,
  1869. method='pad', irlen=None):
  1870. if self.filtfilt_kind == 'tf':
  1871. b, a = zpk2tf(*zpk)
  1872. return filtfilt(b, a, x, axis, padtype, padlen, method, irlen)
  1873. elif self.filtfilt_kind == 'sos':
  1874. sos = zpk2sos(*zpk)
  1875. return sosfiltfilt(sos, x, axis, padtype, padlen)
  1876. def test_basic(self):
  1877. zpk = tf2zpk([1, 2, 3], [1, 2, 3])
  1878. out = self.filtfilt(zpk, np.arange(12))
  1879. assert_allclose(out, arange(12), atol=5.28e-11)
  1880. def test_sine(self):
  1881. rate = 2000
  1882. t = np.linspace(0, 1.0, rate + 1)
  1883. # A signal with low frequency and a high frequency.
  1884. xlow = np.sin(5 * 2 * np.pi * t)
  1885. xhigh = np.sin(250 * 2 * np.pi * t)
  1886. x = xlow + xhigh
  1887. zpk = butter(8, 0.125, output='zpk')
  1888. # r is the magnitude of the largest pole.
  1889. r = np.abs(zpk[1]).max()
  1890. eps = 1e-5
  1891. # n estimates the number of steps for the
  1892. # transient to decay by a factor of eps.
  1893. n = int(np.ceil(np.log(eps) / np.log(r)))
  1894. # High order lowpass filter...
  1895. y = self.filtfilt(zpk, x, padlen=n)
  1896. # Result should be just xlow.
  1897. err = np.abs(y - xlow).max()
  1898. assert_(err < 1e-4)
  1899. # A 2D case.
  1900. x2d = np.vstack([xlow, xlow + xhigh])
  1901. y2d = self.filtfilt(zpk, x2d, padlen=n, axis=1)
  1902. assert_equal(y2d.shape, x2d.shape)
  1903. err = np.abs(y2d - xlow).max()
  1904. assert_(err < 1e-4)
  1905. # Use the previous result to check the use of the axis keyword.
  1906. # (Regression test for ticket #1620)
  1907. y2dt = self.filtfilt(zpk, x2d.T, padlen=n, axis=0)
  1908. assert_equal(y2d, y2dt.T)
  1909. def test_axis(self):
  1910. # Test the 'axis' keyword on a 3D array.
  1911. x = np.arange(10.0 * 11.0 * 12.0).reshape(10, 11, 12)
  1912. zpk = butter(3, 0.125, output='zpk')
  1913. y0 = self.filtfilt(zpk, x, padlen=0, axis=0)
  1914. y1 = self.filtfilt(zpk, np.swapaxes(x, 0, 1), padlen=0, axis=1)
  1915. assert_array_equal(y0, np.swapaxes(y1, 0, 1))
  1916. y2 = self.filtfilt(zpk, np.swapaxes(x, 0, 2), padlen=0, axis=2)
  1917. assert_array_equal(y0, np.swapaxes(y2, 0, 2))
  1918. def test_acoeff(self):
  1919. if self.filtfilt_kind != 'tf':
  1920. return # only necessary for TF
  1921. # test for 'a' coefficient as single number
  1922. out = signal.filtfilt([.5, .5], 1, np.arange(10))
  1923. assert_allclose(out, np.arange(10), rtol=1e-14, atol=1e-14)
  1924. def test_gust_simple(self):
  1925. if self.filtfilt_kind != 'tf':
  1926. pytest.skip('gust only implemented for TF systems')
  1927. # The input array has length 2. The exact solution for this case
  1928. # was computed "by hand".
  1929. x = np.array([1.0, 2.0])
  1930. b = np.array([0.5])
  1931. a = np.array([1.0, -0.5])
  1932. y, z1, z2 = _filtfilt_gust(b, a, x)
  1933. assert_allclose([z1[0], z2[0]],
  1934. [0.3*x[0] + 0.2*x[1], 0.2*x[0] + 0.3*x[1]])
  1935. assert_allclose(y, [z1[0] + 0.25*z2[0] + 0.25*x[0] + 0.125*x[1],
  1936. 0.25*z1[0] + z2[0] + 0.125*x[0] + 0.25*x[1]])
  1937. def test_gust_scalars(self):
  1938. if self.filtfilt_kind != 'tf':
  1939. pytest.skip('gust only implemented for TF systems')
  1940. # The filter coefficients are both scalars, so the filter simply
  1941. # multiplies its input by b/a. When it is used in filtfilt, the
  1942. # factor is (b/a)**2.
  1943. x = np.arange(12)
  1944. b = 3.0
  1945. a = 2.0
  1946. y = filtfilt(b, a, x, method="gust")
  1947. expected = (b/a)**2 * x
  1948. assert_allclose(y, expected)
  1949. class TestSOSFiltFilt(TestFiltFilt):
  1950. filtfilt_kind = 'sos'
  1951. def test_equivalence(self):
  1952. """Test equivalence between sosfiltfilt and filtfilt"""
  1953. x = np.random.RandomState(0).randn(1000)
  1954. for order in range(1, 6):
  1955. zpk = signal.butter(order, 0.35, output='zpk')
  1956. b, a = zpk2tf(*zpk)
  1957. sos = zpk2sos(*zpk)
  1958. y = filtfilt(b, a, x)
  1959. y_sos = sosfiltfilt(sos, x)
  1960. assert_allclose(y, y_sos, atol=1e-12, err_msg='order=%s' % order)
  1961. def filtfilt_gust_opt(b, a, x):
  1962. """
  1963. An alternative implementation of filtfilt with Gustafsson edges.
  1964. This function computes the same result as
  1965. `scipy.signal._signaltools._filtfilt_gust`, but only 1-d arrays
  1966. are accepted. The problem is solved using `fmin` from `scipy.optimize`.
  1967. `_filtfilt_gust` is significanly faster than this implementation.
  1968. """
  1969. def filtfilt_gust_opt_func(ics, b, a, x):
  1970. """Objective function used in filtfilt_gust_opt."""
  1971. m = max(len(a), len(b)) - 1
  1972. z0f = ics[:m]
  1973. z0b = ics[m:]
  1974. y_f = lfilter(b, a, x, zi=z0f)[0]
  1975. y_fb = lfilter(b, a, y_f[::-1], zi=z0b)[0][::-1]
  1976. y_b = lfilter(b, a, x[::-1], zi=z0b)[0][::-1]
  1977. y_bf = lfilter(b, a, y_b, zi=z0f)[0]
  1978. value = np.sum((y_fb - y_bf)**2)
  1979. return value
  1980. m = max(len(a), len(b)) - 1
  1981. zi = lfilter_zi(b, a)
  1982. ics = np.concatenate((x[:m].mean()*zi, x[-m:].mean()*zi))
  1983. result = fmin(filtfilt_gust_opt_func, ics, args=(b, a, x),
  1984. xtol=1e-10, ftol=1e-12,
  1985. maxfun=10000, maxiter=10000,
  1986. full_output=True, disp=False)
  1987. opt, fopt, niter, funcalls, warnflag = result
  1988. if warnflag > 0:
  1989. raise RuntimeError("minimization failed in filtfilt_gust_opt: "
  1990. "warnflag=%d" % warnflag)
  1991. z0f = opt[:m]
  1992. z0b = opt[m:]
  1993. # Apply the forward-backward filter using the computed initial
  1994. # conditions.
  1995. y_b = lfilter(b, a, x[::-1], zi=z0b)[0][::-1]
  1996. y = lfilter(b, a, y_b, zi=z0f)[0]
  1997. return y, z0f, z0b
  1998. def check_filtfilt_gust(b, a, shape, axis, irlen=None):
  1999. # Generate x, the data to be filtered.
  2000. np.random.seed(123)
  2001. x = np.random.randn(*shape)
  2002. # Apply filtfilt to x. This is the main calculation to be checked.
  2003. y = filtfilt(b, a, x, axis=axis, method="gust", irlen=irlen)
  2004. # Also call the private function so we can test the ICs.
  2005. yg, zg1, zg2 = _filtfilt_gust(b, a, x, axis=axis, irlen=irlen)
  2006. # filtfilt_gust_opt is an independent implementation that gives the
  2007. # expected result, but it only handles 1-D arrays, so use some looping
  2008. # and reshaping shenanigans to create the expected output arrays.
  2009. xx = np.swapaxes(x, axis, -1)
  2010. out_shape = xx.shape[:-1]
  2011. yo = np.empty_like(xx)
  2012. m = max(len(a), len(b)) - 1
  2013. zo1 = np.empty(out_shape + (m,))
  2014. zo2 = np.empty(out_shape + (m,))
  2015. for indx in product(*[range(d) for d in out_shape]):
  2016. yo[indx], zo1[indx], zo2[indx] = filtfilt_gust_opt(b, a, xx[indx])
  2017. yo = np.swapaxes(yo, -1, axis)
  2018. zo1 = np.swapaxes(zo1, -1, axis)
  2019. zo2 = np.swapaxes(zo2, -1, axis)
  2020. assert_allclose(y, yo, rtol=1e-8, atol=1e-9)
  2021. assert_allclose(yg, yo, rtol=1e-8, atol=1e-9)
  2022. assert_allclose(zg1, zo1, rtol=1e-8, atol=1e-9)
  2023. assert_allclose(zg2, zo2, rtol=1e-8, atol=1e-9)
  2024. def test_choose_conv_method():
  2025. for mode in ['valid', 'same', 'full']:
  2026. for ndim in [1, 2]:
  2027. n, k, true_method = 8, 6, 'direct'
  2028. x = np.random.randn(*((n,) * ndim))
  2029. h = np.random.randn(*((k,) * ndim))
  2030. method = choose_conv_method(x, h, mode=mode)
  2031. assert_equal(method, true_method)
  2032. method_try, times = choose_conv_method(x, h, mode=mode, measure=True)
  2033. assert_(method_try in {'fft', 'direct'})
  2034. assert_(type(times) is dict)
  2035. assert_('fft' in times.keys() and 'direct' in times.keys())
  2036. n = 10
  2037. for not_fft_conv_supp in ["complex256", "complex192"]:
  2038. if hasattr(np, not_fft_conv_supp):
  2039. x = np.ones(n, dtype=not_fft_conv_supp)
  2040. h = x.copy()
  2041. assert_equal(choose_conv_method(x, h, mode=mode), 'direct')
  2042. x = np.array([2**51], dtype=np.int64)
  2043. h = x.copy()
  2044. assert_equal(choose_conv_method(x, h, mode=mode), 'direct')
  2045. x = [Decimal(3), Decimal(2)]
  2046. h = [Decimal(1), Decimal(4)]
  2047. assert_equal(choose_conv_method(x, h, mode=mode), 'direct')
  2048. def test_filtfilt_gust():
  2049. # Design a filter.
  2050. z, p, k = signal.ellip(3, 0.01, 120, 0.0875, output='zpk')
  2051. # Find the approximate impulse response length of the filter.
  2052. eps = 1e-10
  2053. r = np.max(np.abs(p))
  2054. approx_impulse_len = int(np.ceil(np.log(eps) / np.log(r)))
  2055. np.random.seed(123)
  2056. b, a = zpk2tf(z, p, k)
  2057. for irlen in [None, approx_impulse_len]:
  2058. signal_len = 5 * approx_impulse_len
  2059. # 1-d test case
  2060. check_filtfilt_gust(b, a, (signal_len,), 0, irlen)
  2061. # 3-d test case; test each axis.
  2062. for axis in range(3):
  2063. shape = [2, 2, 2]
  2064. shape[axis] = signal_len
  2065. check_filtfilt_gust(b, a, shape, axis, irlen)
  2066. # Test case with length less than 2*approx_impulse_len.
  2067. # In this case, `filtfilt_gust` should behave the same as if
  2068. # `irlen=None` was given.
  2069. length = 2*approx_impulse_len - 50
  2070. check_filtfilt_gust(b, a, (length,), 0, approx_impulse_len)
  2071. class TestDecimate:
  2072. def test_bad_args(self):
  2073. x = np.arange(12)
  2074. assert_raises(TypeError, signal.decimate, x, q=0.5, n=1)
  2075. assert_raises(TypeError, signal.decimate, x, q=2, n=0.5)
  2076. def test_basic_IIR(self):
  2077. x = np.arange(12)
  2078. y = signal.decimate(x, 2, n=1, ftype='iir', zero_phase=False).round()
  2079. assert_array_equal(y, x[::2])
  2080. def test_basic_FIR(self):
  2081. x = np.arange(12)
  2082. y = signal.decimate(x, 2, n=1, ftype='fir', zero_phase=False).round()
  2083. assert_array_equal(y, x[::2])
  2084. def test_shape(self):
  2085. # Regression test for ticket #1480.
  2086. z = np.zeros((30, 30))
  2087. d0 = signal.decimate(z, 2, axis=0, zero_phase=False)
  2088. assert_equal(d0.shape, (15, 30))
  2089. d1 = signal.decimate(z, 2, axis=1, zero_phase=False)
  2090. assert_equal(d1.shape, (30, 15))
  2091. def test_phaseshift_FIR(self):
  2092. with suppress_warnings() as sup:
  2093. sup.filter(BadCoefficients, "Badly conditioned filter")
  2094. self._test_phaseshift(method='fir', zero_phase=False)
  2095. def test_zero_phase_FIR(self):
  2096. with suppress_warnings() as sup:
  2097. sup.filter(BadCoefficients, "Badly conditioned filter")
  2098. self._test_phaseshift(method='fir', zero_phase=True)
  2099. def test_phaseshift_IIR(self):
  2100. self._test_phaseshift(method='iir', zero_phase=False)
  2101. def test_zero_phase_IIR(self):
  2102. self._test_phaseshift(method='iir', zero_phase=True)
  2103. def _test_phaseshift(self, method, zero_phase):
  2104. rate = 120
  2105. rates_to = [15, 20, 30, 40] # q = 8, 6, 4, 3
  2106. t_tot = int(100) # Need to let antialiasing filters settle
  2107. t = np.arange(rate*t_tot+1) / float(rate)
  2108. # Sinusoids at 0.8*nyquist, windowed to avoid edge artifacts
  2109. freqs = np.array(rates_to) * 0.8 / 2
  2110. d = (np.exp(1j * 2 * np.pi * freqs[:, np.newaxis] * t)
  2111. * signal.windows.tukey(t.size, 0.1))
  2112. for rate_to in rates_to:
  2113. q = rate // rate_to
  2114. t_to = np.arange(rate_to*t_tot+1) / float(rate_to)
  2115. d_tos = (np.exp(1j * 2 * np.pi * freqs[:, np.newaxis] * t_to)
  2116. * signal.windows.tukey(t_to.size, 0.1))
  2117. # Set up downsampling filters, match v0.17 defaults
  2118. if method == 'fir':
  2119. n = 30
  2120. system = signal.dlti(signal.firwin(n + 1, 1. / q,
  2121. window='hamming'), 1.)
  2122. elif method == 'iir':
  2123. n = 8
  2124. wc = 0.8*np.pi/q
  2125. system = signal.dlti(*signal.cheby1(n, 0.05, wc/np.pi))
  2126. # Calculate expected phase response, as unit complex vector
  2127. if zero_phase is False:
  2128. _, h_resps = signal.freqz(system.num, system.den,
  2129. freqs/rate*2*np.pi)
  2130. h_resps /= np.abs(h_resps)
  2131. else:
  2132. h_resps = np.ones_like(freqs)
  2133. y_resamps = signal.decimate(d.real, q, n, ftype=system,
  2134. zero_phase=zero_phase)
  2135. # Get phase from complex inner product, like CSD
  2136. h_resamps = np.sum(d_tos.conj() * y_resamps, axis=-1)
  2137. h_resamps /= np.abs(h_resamps)
  2138. subnyq = freqs < 0.5*rate_to
  2139. # Complex vectors should be aligned, only compare below nyquist
  2140. assert_allclose(np.angle(h_resps.conj()*h_resamps)[subnyq], 0,
  2141. atol=1e-3, rtol=1e-3)
  2142. def test_auto_n(self):
  2143. # Test that our value of n is a reasonable choice (depends on
  2144. # the downsampling factor)
  2145. sfreq = 100.
  2146. n = 1000
  2147. t = np.arange(n) / sfreq
  2148. # will alias for decimations (>= 15)
  2149. x = np.sqrt(2. / n) * np.sin(2 * np.pi * (sfreq / 30.) * t)
  2150. assert_allclose(np.linalg.norm(x), 1., rtol=1e-3)
  2151. x_out = signal.decimate(x, 30, ftype='fir')
  2152. assert_array_less(np.linalg.norm(x_out), 0.01)
  2153. def test_long_float32(self):
  2154. # regression: gh-15072. With 32-bit float and either lfilter
  2155. # or filtfilt, this is numerically unstable
  2156. x = signal.decimate(np.ones(10_000, dtype=np.float32), 10)
  2157. assert not any(np.isnan(x))
  2158. def test_float16_upcast(self):
  2159. # float16 must be upcast to float64
  2160. x = signal.decimate(np.ones(100, dtype=np.float16), 10)
  2161. assert x.dtype.type == np.float64
  2162. class TestHilbert:
  2163. def test_bad_args(self):
  2164. x = np.array([1.0 + 0.0j])
  2165. assert_raises(ValueError, hilbert, x)
  2166. x = np.arange(8.0)
  2167. assert_raises(ValueError, hilbert, x, N=0)
  2168. def test_hilbert_theoretical(self):
  2169. # test cases by Ariel Rokem
  2170. decimal = 14
  2171. pi = np.pi
  2172. t = np.arange(0, 2 * pi, pi / 256)
  2173. a0 = np.sin(t)
  2174. a1 = np.cos(t)
  2175. a2 = np.sin(2 * t)
  2176. a3 = np.cos(2 * t)
  2177. a = np.vstack([a0, a1, a2, a3])
  2178. h = hilbert(a)
  2179. h_abs = np.abs(h)
  2180. h_angle = np.angle(h)
  2181. h_real = np.real(h)
  2182. # The real part should be equal to the original signals:
  2183. assert_almost_equal(h_real, a, decimal)
  2184. # The absolute value should be one everywhere, for this input:
  2185. assert_almost_equal(h_abs, np.ones(a.shape), decimal)
  2186. # For the 'slow' sine - the phase should go from -pi/2 to pi/2 in
  2187. # the first 256 bins:
  2188. assert_almost_equal(h_angle[0, :256],
  2189. np.arange(-pi / 2, pi / 2, pi / 256),
  2190. decimal)
  2191. # For the 'slow' cosine - the phase should go from 0 to pi in the
  2192. # same interval:
  2193. assert_almost_equal(
  2194. h_angle[1, :256], np.arange(0, pi, pi / 256), decimal)
  2195. # The 'fast' sine should make this phase transition in half the time:
  2196. assert_almost_equal(h_angle[2, :128],
  2197. np.arange(-pi / 2, pi / 2, pi / 128),
  2198. decimal)
  2199. # Ditto for the 'fast' cosine:
  2200. assert_almost_equal(
  2201. h_angle[3, :128], np.arange(0, pi, pi / 128), decimal)
  2202. # The imaginary part of hilbert(cos(t)) = sin(t) Wikipedia
  2203. assert_almost_equal(h[1].imag, a0, decimal)
  2204. def test_hilbert_axisN(self):
  2205. # tests for axis and N arguments
  2206. a = np.arange(18).reshape(3, 6)
  2207. # test axis
  2208. aa = hilbert(a, axis=-1)
  2209. assert_equal(hilbert(a.T, axis=0), aa.T)
  2210. # test 1d
  2211. assert_almost_equal(hilbert(a[0]), aa[0], 14)
  2212. # test N
  2213. aan = hilbert(a, N=20, axis=-1)
  2214. assert_equal(aan.shape, [3, 20])
  2215. assert_equal(hilbert(a.T, N=20, axis=0).shape, [20, 3])
  2216. # the next test is just a regression test,
  2217. # no idea whether numbers make sense
  2218. a0hilb = np.array([0.000000000000000e+00 - 1.72015830311905j,
  2219. 1.000000000000000e+00 - 2.047794505137069j,
  2220. 1.999999999999999e+00 - 2.244055555687583j,
  2221. 3.000000000000000e+00 - 1.262750302935009j,
  2222. 4.000000000000000e+00 - 1.066489252384493j,
  2223. 5.000000000000000e+00 + 2.918022706971047j,
  2224. 8.881784197001253e-17 + 3.845658908989067j,
  2225. -9.444121133484362e-17 + 0.985044202202061j,
  2226. -1.776356839400251e-16 + 1.332257797702019j,
  2227. -3.996802888650564e-16 + 0.501905089898885j,
  2228. 1.332267629550188e-16 + 0.668696078880782j,
  2229. -1.192678053963799e-16 + 0.235487067862679j,
  2230. -1.776356839400251e-16 + 0.286439612812121j,
  2231. 3.108624468950438e-16 + 0.031676888064907j,
  2232. 1.332267629550188e-16 - 0.019275656884536j,
  2233. -2.360035624836702e-16 - 0.1652588660287j,
  2234. 0.000000000000000e+00 - 0.332049855010597j,
  2235. 3.552713678800501e-16 - 0.403810179797771j,
  2236. 8.881784197001253e-17 - 0.751023775297729j,
  2237. 9.444121133484362e-17 - 0.79252210110103j])
  2238. assert_almost_equal(aan[0], a0hilb, 14, 'N regression')
  2239. @pytest.mark.parametrize('dtype', [np.float32, np.float64])
  2240. def test_hilbert_types(self, dtype):
  2241. in_typed = np.zeros(8, dtype=dtype)
  2242. assert_equal(np.real(signal.hilbert(in_typed)).dtype, dtype)
  2243. class TestHilbert2:
  2244. def test_bad_args(self):
  2245. # x must be real.
  2246. x = np.array([[1.0 + 0.0j]])
  2247. assert_raises(ValueError, hilbert2, x)
  2248. # x must be rank 2.
  2249. x = np.arange(24).reshape(2, 3, 4)
  2250. assert_raises(ValueError, hilbert2, x)
  2251. # Bad value for N.
  2252. x = np.arange(16).reshape(4, 4)
  2253. assert_raises(ValueError, hilbert2, x, N=0)
  2254. assert_raises(ValueError, hilbert2, x, N=(2, 0))
  2255. assert_raises(ValueError, hilbert2, x, N=(2,))
  2256. @pytest.mark.parametrize('dtype', [np.float32, np.float64])
  2257. def test_hilbert2_types(self, dtype):
  2258. in_typed = np.zeros((2, 32), dtype=dtype)
  2259. assert_equal(np.real(signal.hilbert2(in_typed)).dtype, dtype)
  2260. class TestPartialFractionExpansion:
  2261. @staticmethod
  2262. def assert_rp_almost_equal(r, p, r_true, p_true, decimal=7):
  2263. r_true = np.asarray(r_true)
  2264. p_true = np.asarray(p_true)
  2265. distance = np.hypot(abs(p[:, None] - p_true),
  2266. abs(r[:, None] - r_true))
  2267. rows, cols = linear_sum_assignment(distance)
  2268. assert_almost_equal(p[rows], p_true[cols], decimal=decimal)
  2269. assert_almost_equal(r[rows], r_true[cols], decimal=decimal)
  2270. def test_compute_factors(self):
  2271. factors, poly = _compute_factors([1, 2, 3], [3, 2, 1])
  2272. assert_equal(len(factors), 3)
  2273. assert_almost_equal(factors[0], np.poly([2, 2, 3]))
  2274. assert_almost_equal(factors[1], np.poly([1, 1, 1, 3]))
  2275. assert_almost_equal(factors[2], np.poly([1, 1, 1, 2, 2]))
  2276. assert_almost_equal(poly, np.poly([1, 1, 1, 2, 2, 3]))
  2277. factors, poly = _compute_factors([1, 2, 3], [3, 2, 1],
  2278. include_powers=True)
  2279. assert_equal(len(factors), 6)
  2280. assert_almost_equal(factors[0], np.poly([1, 1, 2, 2, 3]))
  2281. assert_almost_equal(factors[1], np.poly([1, 2, 2, 3]))
  2282. assert_almost_equal(factors[2], np.poly([2, 2, 3]))
  2283. assert_almost_equal(factors[3], np.poly([1, 1, 1, 2, 3]))
  2284. assert_almost_equal(factors[4], np.poly([1, 1, 1, 3]))
  2285. assert_almost_equal(factors[5], np.poly([1, 1, 1, 2, 2]))
  2286. assert_almost_equal(poly, np.poly([1, 1, 1, 2, 2, 3]))
  2287. def test_group_poles(self):
  2288. unique, multiplicity = _group_poles(
  2289. [1.0, 1.001, 1.003, 2.0, 2.003, 3.0], 0.1, 'min')
  2290. assert_equal(unique, [1.0, 2.0, 3.0])
  2291. assert_equal(multiplicity, [3, 2, 1])
  2292. def test_residue_general(self):
  2293. # Test are taken from issue #4464, note that poles in scipy are
  2294. # in increasing by absolute value order, opposite to MATLAB.
  2295. r, p, k = residue([5, 3, -2, 7], [-4, 0, 8, 3])
  2296. assert_almost_equal(r, [1.3320, -0.6653, -1.4167], decimal=4)
  2297. assert_almost_equal(p, [-0.4093, -1.1644, 1.5737], decimal=4)
  2298. assert_almost_equal(k, [-1.2500], decimal=4)
  2299. r, p, k = residue([-4, 8], [1, 6, 8])
  2300. assert_almost_equal(r, [8, -12])
  2301. assert_almost_equal(p, [-2, -4])
  2302. assert_equal(k.size, 0)
  2303. r, p, k = residue([4, 1], [1, -1, -2])
  2304. assert_almost_equal(r, [1, 3])
  2305. assert_almost_equal(p, [-1, 2])
  2306. assert_equal(k.size, 0)
  2307. r, p, k = residue([4, 3], [2, -3.4, 1.98, -0.406])
  2308. self.assert_rp_almost_equal(
  2309. r, p, [-18.125 - 13.125j, -18.125 + 13.125j, 36.25],
  2310. [0.5 - 0.2j, 0.5 + 0.2j, 0.7])
  2311. assert_equal(k.size, 0)
  2312. r, p, k = residue([2, 1], [1, 5, 8, 4])
  2313. self.assert_rp_almost_equal(r, p, [-1, 1, 3], [-1, -2, -2])
  2314. assert_equal(k.size, 0)
  2315. r, p, k = residue([3, -1.1, 0.88, -2.396, 1.348],
  2316. [1, -0.7, -0.14, 0.048])
  2317. assert_almost_equal(r, [-3, 4, 1])
  2318. assert_almost_equal(p, [0.2, -0.3, 0.8])
  2319. assert_almost_equal(k, [3, 1])
  2320. r, p, k = residue([1], [1, 2, -3])
  2321. assert_almost_equal(r, [0.25, -0.25])
  2322. assert_almost_equal(p, [1, -3])
  2323. assert_equal(k.size, 0)
  2324. r, p, k = residue([1, 0, -5], [1, 0, 0, 0, -1])
  2325. self.assert_rp_almost_equal(r, p,
  2326. [1, 1.5j, -1.5j, -1], [-1, -1j, 1j, 1])
  2327. assert_equal(k.size, 0)
  2328. r, p, k = residue([3, 8, 6], [1, 3, 3, 1])
  2329. self.assert_rp_almost_equal(r, p, [1, 2, 3], [-1, -1, -1])
  2330. assert_equal(k.size, 0)
  2331. r, p, k = residue([3, -1], [1, -3, 2])
  2332. assert_almost_equal(r, [-2, 5])
  2333. assert_almost_equal(p, [1, 2])
  2334. assert_equal(k.size, 0)
  2335. r, p, k = residue([2, 3, -1], [1, -3, 2])
  2336. assert_almost_equal(r, [-4, 13])
  2337. assert_almost_equal(p, [1, 2])
  2338. assert_almost_equal(k, [2])
  2339. r, p, k = residue([7, 2, 3, -1], [1, -3, 2])
  2340. assert_almost_equal(r, [-11, 69])
  2341. assert_almost_equal(p, [1, 2])
  2342. assert_almost_equal(k, [7, 23])
  2343. r, p, k = residue([2, 3, -1], [1, -3, 4, -2])
  2344. self.assert_rp_almost_equal(r, p, [4, -1 + 3.5j, -1 - 3.5j],
  2345. [1, 1 - 1j, 1 + 1j])
  2346. assert_almost_equal(k.size, 0)
  2347. def test_residue_leading_zeros(self):
  2348. # Leading zeros in numerator or denominator must not affect the answer.
  2349. r0, p0, k0 = residue([5, 3, -2, 7], [-4, 0, 8, 3])
  2350. r1, p1, k1 = residue([0, 5, 3, -2, 7], [-4, 0, 8, 3])
  2351. r2, p2, k2 = residue([5, 3, -2, 7], [0, -4, 0, 8, 3])
  2352. r3, p3, k3 = residue([0, 0, 5, 3, -2, 7], [0, 0, 0, -4, 0, 8, 3])
  2353. assert_almost_equal(r0, r1)
  2354. assert_almost_equal(r0, r2)
  2355. assert_almost_equal(r0, r3)
  2356. assert_almost_equal(p0, p1)
  2357. assert_almost_equal(p0, p2)
  2358. assert_almost_equal(p0, p3)
  2359. assert_almost_equal(k0, k1)
  2360. assert_almost_equal(k0, k2)
  2361. assert_almost_equal(k0, k3)
  2362. def test_resiude_degenerate(self):
  2363. # Several tests for zero numerator and denominator.
  2364. r, p, k = residue([0, 0], [1, 6, 8])
  2365. assert_almost_equal(r, [0, 0])
  2366. assert_almost_equal(p, [-2, -4])
  2367. assert_equal(k.size, 0)
  2368. r, p, k = residue(0, 1)
  2369. assert_equal(r.size, 0)
  2370. assert_equal(p.size, 0)
  2371. assert_equal(k.size, 0)
  2372. with pytest.raises(ValueError, match="Denominator `a` is zero."):
  2373. residue(1, 0)
  2374. def test_residuez_general(self):
  2375. r, p, k = residuez([1, 6, 6, 2], [1, -(2 + 1j), (1 + 2j), -1j])
  2376. self.assert_rp_almost_equal(r, p, [-2+2.5j, 7.5+7.5j, -4.5-12j],
  2377. [1j, 1, 1])
  2378. assert_almost_equal(k, [2j])
  2379. r, p, k = residuez([1, 2, 1], [1, -1, 0.3561])
  2380. self.assert_rp_almost_equal(r, p,
  2381. [-0.9041 - 5.9928j, -0.9041 + 5.9928j],
  2382. [0.5 + 0.3257j, 0.5 - 0.3257j],
  2383. decimal=4)
  2384. assert_almost_equal(k, [2.8082], decimal=4)
  2385. r, p, k = residuez([1, -1], [1, -5, 6])
  2386. assert_almost_equal(r, [-1, 2])
  2387. assert_almost_equal(p, [2, 3])
  2388. assert_equal(k.size, 0)
  2389. r, p, k = residuez([2, 3, 4], [1, 3, 3, 1])
  2390. self.assert_rp_almost_equal(r, p, [4, -5, 3], [-1, -1, -1])
  2391. assert_equal(k.size, 0)
  2392. r, p, k = residuez([1, -10, -4, 4], [2, -2, -4])
  2393. assert_almost_equal(r, [0.5, -1.5])
  2394. assert_almost_equal(p, [-1, 2])
  2395. assert_almost_equal(k, [1.5, -1])
  2396. r, p, k = residuez([18], [18, 3, -4, -1])
  2397. self.assert_rp_almost_equal(r, p,
  2398. [0.36, 0.24, 0.4], [0.5, -1/3, -1/3])
  2399. assert_equal(k.size, 0)
  2400. r, p, k = residuez([2, 3], np.polymul([1, -1/2], [1, 1/4]))
  2401. assert_almost_equal(r, [-10/3, 16/3])
  2402. assert_almost_equal(p, [-0.25, 0.5])
  2403. assert_equal(k.size, 0)
  2404. r, p, k = residuez([1, -2, 1], [1, -1])
  2405. assert_almost_equal(r, [0])
  2406. assert_almost_equal(p, [1])
  2407. assert_almost_equal(k, [1, -1])
  2408. r, p, k = residuez(1, [1, -1j])
  2409. assert_almost_equal(r, [1])
  2410. assert_almost_equal(p, [1j])
  2411. assert_equal(k.size, 0)
  2412. r, p, k = residuez(1, [1, -1, 0.25])
  2413. assert_almost_equal(r, [0, 1])
  2414. assert_almost_equal(p, [0.5, 0.5])
  2415. assert_equal(k.size, 0)
  2416. r, p, k = residuez(1, [1, -0.75, .125])
  2417. assert_almost_equal(r, [-1, 2])
  2418. assert_almost_equal(p, [0.25, 0.5])
  2419. assert_equal(k.size, 0)
  2420. r, p, k = residuez([1, 6, 2], [1, -2, 1])
  2421. assert_almost_equal(r, [-10, 9])
  2422. assert_almost_equal(p, [1, 1])
  2423. assert_almost_equal(k, [2])
  2424. r, p, k = residuez([6, 2], [1, -2, 1])
  2425. assert_almost_equal(r, [-2, 8])
  2426. assert_almost_equal(p, [1, 1])
  2427. assert_equal(k.size, 0)
  2428. r, p, k = residuez([1, 6, 6, 2], [1, -2, 1])
  2429. assert_almost_equal(r, [-24, 15])
  2430. assert_almost_equal(p, [1, 1])
  2431. assert_almost_equal(k, [10, 2])
  2432. r, p, k = residuez([1, 0, 1], [1, 0, 0, 0, 0, -1])
  2433. self.assert_rp_almost_equal(r, p,
  2434. [0.2618 + 0.1902j, 0.2618 - 0.1902j,
  2435. 0.4, 0.0382 - 0.1176j, 0.0382 + 0.1176j],
  2436. [-0.8090 + 0.5878j, -0.8090 - 0.5878j,
  2437. 1.0, 0.3090 + 0.9511j, 0.3090 - 0.9511j],
  2438. decimal=4)
  2439. assert_equal(k.size, 0)
  2440. def test_residuez_trailing_zeros(self):
  2441. # Trailing zeros in numerator or denominator must not affect the
  2442. # answer.
  2443. r0, p0, k0 = residuez([5, 3, -2, 7], [-4, 0, 8, 3])
  2444. r1, p1, k1 = residuez([5, 3, -2, 7, 0], [-4, 0, 8, 3])
  2445. r2, p2, k2 = residuez([5, 3, -2, 7], [-4, 0, 8, 3, 0])
  2446. r3, p3, k3 = residuez([5, 3, -2, 7, 0, 0], [-4, 0, 8, 3, 0, 0, 0])
  2447. assert_almost_equal(r0, r1)
  2448. assert_almost_equal(r0, r2)
  2449. assert_almost_equal(r0, r3)
  2450. assert_almost_equal(p0, p1)
  2451. assert_almost_equal(p0, p2)
  2452. assert_almost_equal(p0, p3)
  2453. assert_almost_equal(k0, k1)
  2454. assert_almost_equal(k0, k2)
  2455. assert_almost_equal(k0, k3)
  2456. def test_residuez_degenerate(self):
  2457. r, p, k = residuez([0, 0], [1, 6, 8])
  2458. assert_almost_equal(r, [0, 0])
  2459. assert_almost_equal(p, [-2, -4])
  2460. assert_equal(k.size, 0)
  2461. r, p, k = residuez(0, 1)
  2462. assert_equal(r.size, 0)
  2463. assert_equal(p.size, 0)
  2464. assert_equal(k.size, 0)
  2465. with pytest.raises(ValueError, match="Denominator `a` is zero."):
  2466. residuez(1, 0)
  2467. with pytest.raises(ValueError,
  2468. match="First coefficient of determinant `a` must "
  2469. "be non-zero."):
  2470. residuez(1, [0, 1, 2, 3])
  2471. def test_inverse_unique_roots_different_rtypes(self):
  2472. # This test was inspired by github issue 2496.
  2473. r = [3 / 10, -1 / 6, -2 / 15]
  2474. p = [0, -2, -5]
  2475. k = []
  2476. b_expected = [0, 1, 3]
  2477. a_expected = [1, 7, 10, 0]
  2478. # With the default tolerance, the rtype does not matter
  2479. # for this example.
  2480. for rtype in ('avg', 'mean', 'min', 'minimum', 'max', 'maximum'):
  2481. b, a = invres(r, p, k, rtype=rtype)
  2482. assert_allclose(b, b_expected)
  2483. assert_allclose(a, a_expected)
  2484. b, a = invresz(r, p, k, rtype=rtype)
  2485. assert_allclose(b, b_expected)
  2486. assert_allclose(a, a_expected)
  2487. def test_inverse_repeated_roots_different_rtypes(self):
  2488. r = [3 / 20, -7 / 36, -1 / 6, 2 / 45]
  2489. p = [0, -2, -2, -5]
  2490. k = []
  2491. b_expected = [0, 0, 1, 3]
  2492. b_expected_z = [-1/6, -2/3, 11/6, 3]
  2493. a_expected = [1, 9, 24, 20, 0]
  2494. for rtype in ('avg', 'mean', 'min', 'minimum', 'max', 'maximum'):
  2495. b, a = invres(r, p, k, rtype=rtype)
  2496. assert_allclose(b, b_expected, atol=1e-14)
  2497. assert_allclose(a, a_expected)
  2498. b, a = invresz(r, p, k, rtype=rtype)
  2499. assert_allclose(b, b_expected_z, atol=1e-14)
  2500. assert_allclose(a, a_expected)
  2501. def test_inverse_bad_rtype(self):
  2502. r = [3 / 20, -7 / 36, -1 / 6, 2 / 45]
  2503. p = [0, -2, -2, -5]
  2504. k = []
  2505. with pytest.raises(ValueError, match="`rtype` must be one of"):
  2506. invres(r, p, k, rtype='median')
  2507. with pytest.raises(ValueError, match="`rtype` must be one of"):
  2508. invresz(r, p, k, rtype='median')
  2509. def test_invresz_one_coefficient_bug(self):
  2510. # Regression test for issue in gh-4646.
  2511. r = [1]
  2512. p = [2]
  2513. k = [0]
  2514. b, a = invresz(r, p, k)
  2515. assert_allclose(b, [1.0])
  2516. assert_allclose(a, [1.0, -2.0])
  2517. def test_invres(self):
  2518. b, a = invres([1], [1], [])
  2519. assert_almost_equal(b, [1])
  2520. assert_almost_equal(a, [1, -1])
  2521. b, a = invres([1 - 1j, 2, 0.5 - 3j], [1, 0.5j, 1 + 1j], [])
  2522. assert_almost_equal(b, [3.5 - 4j, -8.5 + 0.25j, 3.5 + 3.25j])
  2523. assert_almost_equal(a, [1, -2 - 1.5j, 0.5 + 2j, 0.5 - 0.5j])
  2524. b, a = invres([0.5, 1], [1 - 1j, 2 + 2j], [1, 2, 3])
  2525. assert_almost_equal(b, [1, -1 - 1j, 1 - 2j, 0.5 - 3j, 10])
  2526. assert_almost_equal(a, [1, -3 - 1j, 4])
  2527. b, a = invres([-1, 2, 1j, 3 - 1j, 4, -2],
  2528. [-1, 2 - 1j, 2 - 1j, 3, 3, 3], [])
  2529. assert_almost_equal(b, [4 - 1j, -28 + 16j, 40 - 62j, 100 + 24j,
  2530. -292 + 219j, 192 - 268j])
  2531. assert_almost_equal(a, [1, -12 + 2j, 53 - 20j, -96 + 68j, 27 - 72j,
  2532. 108 - 54j, -81 + 108j])
  2533. b, a = invres([-1, 1j], [1, 1], [1, 2])
  2534. assert_almost_equal(b, [1, 0, -4, 3 + 1j])
  2535. assert_almost_equal(a, [1, -2, 1])
  2536. def test_invresz(self):
  2537. b, a = invresz([1], [1], [])
  2538. assert_almost_equal(b, [1])
  2539. assert_almost_equal(a, [1, -1])
  2540. b, a = invresz([1 - 1j, 2, 0.5 - 3j], [1, 0.5j, 1 + 1j], [])
  2541. assert_almost_equal(b, [3.5 - 4j, -8.5 + 0.25j, 3.5 + 3.25j])
  2542. assert_almost_equal(a, [1, -2 - 1.5j, 0.5 + 2j, 0.5 - 0.5j])
  2543. b, a = invresz([0.5, 1], [1 - 1j, 2 + 2j], [1, 2, 3])
  2544. assert_almost_equal(b, [2.5, -3 - 1j, 1 - 2j, -1 - 3j, 12])
  2545. assert_almost_equal(a, [1, -3 - 1j, 4])
  2546. b, a = invresz([-1, 2, 1j, 3 - 1j, 4, -2],
  2547. [-1, 2 - 1j, 2 - 1j, 3, 3, 3], [])
  2548. assert_almost_equal(b, [6, -50 + 11j, 100 - 72j, 80 + 58j,
  2549. -354 + 228j, 234 - 297j])
  2550. assert_almost_equal(a, [1, -12 + 2j, 53 - 20j, -96 + 68j, 27 - 72j,
  2551. 108 - 54j, -81 + 108j])
  2552. b, a = invresz([-1, 1j], [1, 1], [1, 2])
  2553. assert_almost_equal(b, [1j, 1, -3, 2])
  2554. assert_almost_equal(a, [1, -2, 1])
  2555. def test_inverse_scalar_arguments(self):
  2556. b, a = invres(1, 1, 1)
  2557. assert_almost_equal(b, [1, 0])
  2558. assert_almost_equal(a, [1, -1])
  2559. b, a = invresz(1, 1, 1)
  2560. assert_almost_equal(b, [2, -1])
  2561. assert_almost_equal(a, [1, -1])
  2562. class TestVectorstrength:
  2563. def test_single_1dperiod(self):
  2564. events = np.array([.5])
  2565. period = 5.
  2566. targ_strength = 1.
  2567. targ_phase = .1
  2568. strength, phase = vectorstrength(events, period)
  2569. assert_equal(strength.ndim, 0)
  2570. assert_equal(phase.ndim, 0)
  2571. assert_almost_equal(strength, targ_strength)
  2572. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2573. def test_single_2dperiod(self):
  2574. events = np.array([.5])
  2575. period = [1, 2, 5.]
  2576. targ_strength = [1.] * 3
  2577. targ_phase = np.array([.5, .25, .1])
  2578. strength, phase = vectorstrength(events, period)
  2579. assert_equal(strength.ndim, 1)
  2580. assert_equal(phase.ndim, 1)
  2581. assert_array_almost_equal(strength, targ_strength)
  2582. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2583. def test_equal_1dperiod(self):
  2584. events = np.array([.25, .25, .25, .25, .25, .25])
  2585. period = 2
  2586. targ_strength = 1.
  2587. targ_phase = .125
  2588. strength, phase = vectorstrength(events, period)
  2589. assert_equal(strength.ndim, 0)
  2590. assert_equal(phase.ndim, 0)
  2591. assert_almost_equal(strength, targ_strength)
  2592. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2593. def test_equal_2dperiod(self):
  2594. events = np.array([.25, .25, .25, .25, .25, .25])
  2595. period = [1, 2, ]
  2596. targ_strength = [1.] * 2
  2597. targ_phase = np.array([.25, .125])
  2598. strength, phase = vectorstrength(events, period)
  2599. assert_equal(strength.ndim, 1)
  2600. assert_equal(phase.ndim, 1)
  2601. assert_almost_equal(strength, targ_strength)
  2602. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2603. def test_spaced_1dperiod(self):
  2604. events = np.array([.1, 1.1, 2.1, 4.1, 10.1])
  2605. period = 1
  2606. targ_strength = 1.
  2607. targ_phase = .1
  2608. strength, phase = vectorstrength(events, period)
  2609. assert_equal(strength.ndim, 0)
  2610. assert_equal(phase.ndim, 0)
  2611. assert_almost_equal(strength, targ_strength)
  2612. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2613. def test_spaced_2dperiod(self):
  2614. events = np.array([.1, 1.1, 2.1, 4.1, 10.1])
  2615. period = [1, .5]
  2616. targ_strength = [1.] * 2
  2617. targ_phase = np.array([.1, .2])
  2618. strength, phase = vectorstrength(events, period)
  2619. assert_equal(strength.ndim, 1)
  2620. assert_equal(phase.ndim, 1)
  2621. assert_almost_equal(strength, targ_strength)
  2622. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2623. def test_partial_1dperiod(self):
  2624. events = np.array([.25, .5, .75])
  2625. period = 1
  2626. targ_strength = 1. / 3.
  2627. targ_phase = .5
  2628. strength, phase = vectorstrength(events, period)
  2629. assert_equal(strength.ndim, 0)
  2630. assert_equal(phase.ndim, 0)
  2631. assert_almost_equal(strength, targ_strength)
  2632. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2633. def test_partial_2dperiod(self):
  2634. events = np.array([.25, .5, .75])
  2635. period = [1., 1., 1., 1.]
  2636. targ_strength = [1. / 3.] * 4
  2637. targ_phase = np.array([.5, .5, .5, .5])
  2638. strength, phase = vectorstrength(events, period)
  2639. assert_equal(strength.ndim, 1)
  2640. assert_equal(phase.ndim, 1)
  2641. assert_almost_equal(strength, targ_strength)
  2642. assert_almost_equal(phase, 2 * np.pi * targ_phase)
  2643. def test_opposite_1dperiod(self):
  2644. events = np.array([0, .25, .5, .75])
  2645. period = 1.
  2646. targ_strength = 0
  2647. strength, phase = vectorstrength(events, period)
  2648. assert_equal(strength.ndim, 0)
  2649. assert_equal(phase.ndim, 0)
  2650. assert_almost_equal(strength, targ_strength)
  2651. def test_opposite_2dperiod(self):
  2652. events = np.array([0, .25, .5, .75])
  2653. period = [1.] * 10
  2654. targ_strength = [0.] * 10
  2655. strength, phase = vectorstrength(events, period)
  2656. assert_equal(strength.ndim, 1)
  2657. assert_equal(phase.ndim, 1)
  2658. assert_almost_equal(strength, targ_strength)
  2659. def test_2d_events_ValueError(self):
  2660. events = np.array([[1, 2]])
  2661. period = 1.
  2662. assert_raises(ValueError, vectorstrength, events, period)
  2663. def test_2d_period_ValueError(self):
  2664. events = 1.
  2665. period = np.array([[1]])
  2666. assert_raises(ValueError, vectorstrength, events, period)
  2667. def test_zero_period_ValueError(self):
  2668. events = 1.
  2669. period = 0
  2670. assert_raises(ValueError, vectorstrength, events, period)
  2671. def test_negative_period_ValueError(self):
  2672. events = 1.
  2673. period = -1
  2674. assert_raises(ValueError, vectorstrength, events, period)
  2675. def cast_tf2sos(b, a):
  2676. """Convert TF2SOS, casting to complex128 and back to the original dtype."""
  2677. # tf2sos does not support all of the dtypes that we want to check, e.g.:
  2678. #
  2679. # TypeError: array type complex256 is unsupported in linalg
  2680. #
  2681. # so let's cast, convert, and cast back -- should be fine for the
  2682. # systems and precisions we are testing.
  2683. dtype = np.asarray(b).dtype
  2684. b = np.array(b, np.complex128)
  2685. a = np.array(a, np.complex128)
  2686. return tf2sos(b, a).astype(dtype)
  2687. def assert_allclose_cast(actual, desired, rtol=1e-7, atol=0):
  2688. """Wrap assert_allclose while casting object arrays."""
  2689. if actual.dtype.kind == 'O':
  2690. dtype = np.array(actual.flat[0]).dtype
  2691. actual, desired = actual.astype(dtype), desired.astype(dtype)
  2692. assert_allclose(actual, desired, rtol, atol)
  2693. @pytest.mark.parametrize('func', (sosfilt, lfilter))
  2694. def test_nonnumeric_dtypes(func):
  2695. x = [Decimal(1), Decimal(2), Decimal(3)]
  2696. b = [Decimal(1), Decimal(2), Decimal(3)]
  2697. a = [Decimal(1), Decimal(2), Decimal(3)]
  2698. x = np.array(x)
  2699. assert x.dtype.kind == 'O'
  2700. desired = lfilter(np.array(b, float), np.array(a, float), x.astype(float))
  2701. if func is sosfilt:
  2702. actual = sosfilt([b + a], x)
  2703. else:
  2704. actual = lfilter(b, a, x)
  2705. assert all(isinstance(x, Decimal) for x in actual)
  2706. assert_allclose(actual.astype(float), desired.astype(float))
  2707. # Degenerate cases
  2708. if func is lfilter:
  2709. args = [1., 1.]
  2710. else:
  2711. args = [tf2sos(1., 1.)]
  2712. with pytest.raises(ValueError, match='must be at least 1-D'):
  2713. func(*args, x=1.)
  2714. @pytest.mark.parametrize('dt', 'fdgFDGO')
  2715. class TestSOSFilt:
  2716. # The test_rank* tests are pulled from _TestLinearFilter
  2717. def test_rank1(self, dt):
  2718. x = np.linspace(0, 5, 6).astype(dt)
  2719. b = np.array([1, -1]).astype(dt)
  2720. a = np.array([0.5, -0.5]).astype(dt)
  2721. # Test simple IIR
  2722. y_r = np.array([0, 2, 4, 6, 8, 10.]).astype(dt)
  2723. sos = cast_tf2sos(b, a)
  2724. assert sos.dtype.char == dt
  2725. assert_array_almost_equal(sosfilt(cast_tf2sos(b, a), x), y_r)
  2726. # Test simple FIR
  2727. b = np.array([1, 1]).astype(dt)
  2728. # NOTE: This was changed (rel. to TestLinear...) to add a pole @zero:
  2729. a = np.array([1, 0]).astype(dt)
  2730. y_r = np.array([0, 1, 3, 5, 7, 9.]).astype(dt)
  2731. assert_array_almost_equal(sosfilt(cast_tf2sos(b, a), x), y_r)
  2732. b = [1, 1, 0]
  2733. a = [1, 0, 0]
  2734. x = np.ones(8)
  2735. sos = np.concatenate((b, a))
  2736. sos.shape = (1, 6)
  2737. y = sosfilt(sos, x)
  2738. assert_allclose(y, [1, 2, 2, 2, 2, 2, 2, 2])
  2739. def test_rank2(self, dt):
  2740. shape = (4, 3)
  2741. x = np.linspace(0, np.prod(shape) - 1, np.prod(shape)).reshape(shape)
  2742. x = x.astype(dt)
  2743. b = np.array([1, -1]).astype(dt)
  2744. a = np.array([0.5, 0.5]).astype(dt)
  2745. y_r2_a0 = np.array([[0, 2, 4], [6, 4, 2], [0, 2, 4], [6, 4, 2]],
  2746. dtype=dt)
  2747. y_r2_a1 = np.array([[0, 2, 0], [6, -4, 6], [12, -10, 12],
  2748. [18, -16, 18]], dtype=dt)
  2749. y = sosfilt(cast_tf2sos(b, a), x, axis=0)
  2750. assert_array_almost_equal(y_r2_a0, y)
  2751. y = sosfilt(cast_tf2sos(b, a), x, axis=1)
  2752. assert_array_almost_equal(y_r2_a1, y)
  2753. def test_rank3(self, dt):
  2754. shape = (4, 3, 2)
  2755. x = np.linspace(0, np.prod(shape) - 1, np.prod(shape)).reshape(shape)
  2756. b = np.array([1, -1]).astype(dt)
  2757. a = np.array([0.5, 0.5]).astype(dt)
  2758. # Test last axis
  2759. y = sosfilt(cast_tf2sos(b, a), x)
  2760. for i in range(x.shape[0]):
  2761. for j in range(x.shape[1]):
  2762. assert_array_almost_equal(y[i, j], lfilter(b, a, x[i, j]))
  2763. def test_initial_conditions(self, dt):
  2764. b1, a1 = signal.butter(2, 0.25, 'low')
  2765. b2, a2 = signal.butter(2, 0.75, 'low')
  2766. b3, a3 = signal.butter(2, 0.75, 'low')
  2767. b = np.convolve(np.convolve(b1, b2), b3)
  2768. a = np.convolve(np.convolve(a1, a2), a3)
  2769. sos = np.array((np.r_[b1, a1], np.r_[b2, a2], np.r_[b3, a3]))
  2770. x = np.random.rand(50).astype(dt)
  2771. # Stopping filtering and continuing
  2772. y_true, zi = lfilter(b, a, x[:20], zi=np.zeros(6))
  2773. y_true = np.r_[y_true, lfilter(b, a, x[20:], zi=zi)[0]]
  2774. assert_allclose_cast(y_true, lfilter(b, a, x))
  2775. y_sos, zi = sosfilt(sos, x[:20], zi=np.zeros((3, 2)))
  2776. y_sos = np.r_[y_sos, sosfilt(sos, x[20:], zi=zi)[0]]
  2777. assert_allclose_cast(y_true, y_sos)
  2778. # Use a step function
  2779. zi = sosfilt_zi(sos)
  2780. x = np.ones(8, dt)
  2781. y, zf = sosfilt(sos, x, zi=zi)
  2782. assert_allclose_cast(y, np.ones(8))
  2783. assert_allclose_cast(zf, zi)
  2784. # Initial condition shape matching
  2785. x.shape = (1, 1) + x.shape # 3D
  2786. assert_raises(ValueError, sosfilt, sos, x, zi=zi)
  2787. zi_nd = zi.copy()
  2788. zi_nd.shape = (zi.shape[0], 1, 1, zi.shape[-1])
  2789. assert_raises(ValueError, sosfilt, sos, x,
  2790. zi=zi_nd[:, :, :, [0, 1, 1]])
  2791. y, zf = sosfilt(sos, x, zi=zi_nd)
  2792. assert_allclose_cast(y[0, 0], np.ones(8))
  2793. assert_allclose_cast(zf[:, 0, 0, :], zi)
  2794. def test_initial_conditions_3d_axis1(self, dt):
  2795. # Test the use of zi when sosfilt is applied to axis 1 of a 3-d input.
  2796. # Input array is x.
  2797. x = np.random.RandomState(159).randint(0, 5, size=(2, 15, 3))
  2798. x = x.astype(dt)
  2799. # Design a filter in ZPK format and convert to SOS
  2800. zpk = signal.butter(6, 0.35, output='zpk')
  2801. sos = zpk2sos(*zpk)
  2802. nsections = sos.shape[0]
  2803. # Filter along this axis.
  2804. axis = 1
  2805. # Initial conditions, all zeros.
  2806. shp = list(x.shape)
  2807. shp[axis] = 2
  2808. shp = [nsections] + shp
  2809. z0 = np.zeros(shp)
  2810. # Apply the filter to x.
  2811. yf, zf = sosfilt(sos, x, axis=axis, zi=z0)
  2812. # Apply the filter to x in two stages.
  2813. y1, z1 = sosfilt(sos, x[:, :5, :], axis=axis, zi=z0)
  2814. y2, z2 = sosfilt(sos, x[:, 5:, :], axis=axis, zi=z1)
  2815. # y should equal yf, and z2 should equal zf.
  2816. y = np.concatenate((y1, y2), axis=axis)
  2817. assert_allclose_cast(y, yf, rtol=1e-10, atol=1e-13)
  2818. assert_allclose_cast(z2, zf, rtol=1e-10, atol=1e-13)
  2819. # let's try the "step" initial condition
  2820. zi = sosfilt_zi(sos)
  2821. zi.shape = [nsections, 1, 2, 1]
  2822. zi = zi * x[:, 0:1, :]
  2823. y = sosfilt(sos, x, axis=axis, zi=zi)[0]
  2824. # check it against the TF form
  2825. b, a = zpk2tf(*zpk)
  2826. zi = lfilter_zi(b, a)
  2827. zi.shape = [1, zi.size, 1]
  2828. zi = zi * x[:, 0:1, :]
  2829. y_tf = lfilter(b, a, x, axis=axis, zi=zi)[0]
  2830. assert_allclose_cast(y, y_tf, rtol=1e-10, atol=1e-13)
  2831. def test_bad_zi_shape(self, dt):
  2832. # The shape of zi is checked before using any values in the
  2833. # arguments, so np.empty is fine for creating the arguments.
  2834. x = np.empty((3, 15, 3), dt)
  2835. sos = np.zeros((4, 6))
  2836. zi = np.empty((4, 3, 3, 2)) # Correct shape is (4, 3, 2, 3)
  2837. with pytest.raises(ValueError, match='should be all ones'):
  2838. sosfilt(sos, x, zi=zi, axis=1)
  2839. sos[:, 3] = 1.
  2840. with pytest.raises(ValueError, match='Invalid zi shape'):
  2841. sosfilt(sos, x, zi=zi, axis=1)
  2842. def test_sosfilt_zi(self, dt):
  2843. sos = signal.butter(6, 0.2, output='sos')
  2844. zi = sosfilt_zi(sos)
  2845. y, zf = sosfilt(sos, np.ones(40, dt), zi=zi)
  2846. assert_allclose_cast(zf, zi, rtol=1e-13)
  2847. # Expected steady state value of the step response of this filter:
  2848. ss = np.prod(sos[:, :3].sum(axis=-1) / sos[:, 3:].sum(axis=-1))
  2849. assert_allclose_cast(y, ss, rtol=1e-13)
  2850. # zi as array-like
  2851. _, zf = sosfilt(sos, np.ones(40, dt), zi=zi.tolist())
  2852. assert_allclose_cast(zf, zi, rtol=1e-13)
  2853. class TestDeconvolve:
  2854. def test_basic(self):
  2855. # From docstring example
  2856. original = [0, 1, 0, 0, 1, 1, 0, 0]
  2857. impulse_response = [2, 1]
  2858. recorded = [0, 2, 1, 0, 2, 3, 1, 0, 0]
  2859. recovered, remainder = signal.deconvolve(recorded, impulse_response)
  2860. assert_allclose(recovered, original)
  2861. def test_n_dimensional_signal(self):
  2862. recorded = [[0, 0], [0, 0]]
  2863. impulse_response = [0, 0]
  2864. with pytest.raises(ValueError, match="signal must be 1-D."):
  2865. quotient, remainder = signal.deconvolve(recorded, impulse_response)
  2866. def test_n_dimensional_divisor(self):
  2867. recorded = [0, 0]
  2868. impulse_response = [[0, 0], [0, 0]]
  2869. with pytest.raises(ValueError, match="divisor must be 1-D."):
  2870. quotient, remainder = signal.deconvolve(recorded, impulse_response)
  2871. class TestDetrend:
  2872. def test_basic(self):
  2873. detrended = detrend(array([1, 2, 3]))
  2874. detrended_exact = array([0, 0, 0])
  2875. assert_array_almost_equal(detrended, detrended_exact)
  2876. def test_copy(self):
  2877. x = array([1, 1.2, 1.5, 1.6, 2.4])
  2878. copy_array = detrend(x, overwrite_data=False)
  2879. inplace = detrend(x, overwrite_data=True)
  2880. assert_array_almost_equal(copy_array, inplace)
  2881. class TestUniqueRoots:
  2882. def test_real_no_repeat(self):
  2883. p = [-1.0, -0.5, 0.3, 1.2, 10.0]
  2884. unique, multiplicity = unique_roots(p)
  2885. assert_almost_equal(unique, p, decimal=15)
  2886. assert_equal(multiplicity, np.ones(len(p)))
  2887. def test_real_repeat(self):
  2888. p = [-1.0, -0.95, -0.89, -0.8, 0.5, 1.0, 1.05]
  2889. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='min')
  2890. assert_almost_equal(unique, [-1.0, -0.89, 0.5, 1.0], decimal=15)
  2891. assert_equal(multiplicity, [2, 2, 1, 2])
  2892. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='max')
  2893. assert_almost_equal(unique, [-0.95, -0.8, 0.5, 1.05], decimal=15)
  2894. assert_equal(multiplicity, [2, 2, 1, 2])
  2895. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='avg')
  2896. assert_almost_equal(unique, [-0.975, -0.845, 0.5, 1.025], decimal=15)
  2897. assert_equal(multiplicity, [2, 2, 1, 2])
  2898. def test_complex_no_repeat(self):
  2899. p = [-1.0, 1.0j, 0.5 + 0.5j, -1.0 - 1.0j, 3.0 + 2.0j]
  2900. unique, multiplicity = unique_roots(p)
  2901. assert_almost_equal(unique, p, decimal=15)
  2902. assert_equal(multiplicity, np.ones(len(p)))
  2903. def test_complex_repeat(self):
  2904. p = [-1.0, -1.0 + 0.05j, -0.95 + 0.15j, -0.90 + 0.15j, 0.0,
  2905. 0.5 + 0.5j, 0.45 + 0.55j]
  2906. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='min')
  2907. assert_almost_equal(unique, [-1.0, -0.95 + 0.15j, 0.0, 0.45 + 0.55j],
  2908. decimal=15)
  2909. assert_equal(multiplicity, [2, 2, 1, 2])
  2910. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='max')
  2911. assert_almost_equal(unique,
  2912. [-1.0 + 0.05j, -0.90 + 0.15j, 0.0, 0.5 + 0.5j],
  2913. decimal=15)
  2914. assert_equal(multiplicity, [2, 2, 1, 2])
  2915. unique, multiplicity = unique_roots(p, tol=1e-1, rtype='avg')
  2916. assert_almost_equal(
  2917. unique, [-1.0 + 0.025j, -0.925 + 0.15j, 0.0, 0.475 + 0.525j],
  2918. decimal=15)
  2919. assert_equal(multiplicity, [2, 2, 1, 2])
  2920. def test_gh_4915(self):
  2921. p = np.roots(np.convolve(np.ones(5), np.ones(5)))
  2922. true_roots = [-(-1)**(1/5), (-1)**(4/5), -(-1)**(3/5), (-1)**(2/5)]
  2923. unique, multiplicity = unique_roots(p)
  2924. unique = np.sort(unique)
  2925. assert_almost_equal(np.sort(unique), true_roots, decimal=7)
  2926. assert_equal(multiplicity, [2, 2, 2, 2])
  2927. def test_complex_roots_extra(self):
  2928. unique, multiplicity = unique_roots([1.0, 1.0j, 1.0])
  2929. assert_almost_equal(unique, [1.0, 1.0j], decimal=15)
  2930. assert_equal(multiplicity, [2, 1])
  2931. unique, multiplicity = unique_roots([1, 1 + 2e-9, 1e-9 + 1j], tol=0.1)
  2932. assert_almost_equal(unique, [1.0, 1e-9 + 1.0j], decimal=15)
  2933. assert_equal(multiplicity, [2, 1])
  2934. def test_single_unique_root(self):
  2935. p = np.random.rand(100) + 1j * np.random.rand(100)
  2936. unique, multiplicity = unique_roots(p, 2)
  2937. assert_almost_equal(unique, [np.min(p)], decimal=15)
  2938. assert_equal(multiplicity, [100])