12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700 |
- import itertools
- import numpy as np
- from numpy.testing import assert_, assert_allclose, assert_equal
- from pytest import raises as assert_raises
- from scipy import linalg
- import scipy.linalg._decomp_update as _decomp_update
- from scipy.linalg._decomp_update import qr_delete, qr_update, qr_insert
def assert_unitary(a, rtol=None, atol=None, assert_sqr=True):
    """Assert that the columns of ``a`` are orthonormal.

    Parameters
    ----------
    a : ndarray
        Candidate matrix; ``a^H a`` is compared against the identity.
    rtol, atol : float, optional
        Tolerances forwarded to ``assert_allclose``. When omitted they are
        derived from ``np.finfo(a.dtype)``.
    assert_sqr : bool, optional
        When true, additionally require ``a`` to be square.
    """
    # finfo is only consulted for the tolerances that were not supplied.
    if rtol is None:
        digits = np.finfo(a.dtype).precision
        rtol = 10.0 ** -(digits - 2)
    if atol is None:
        atol = 10 * np.finfo(a.dtype).eps

    ncols = a.shape[1]
    if assert_sqr:
        assert_(a.shape[0] == ncols, 'unitary matrices must be square')

    gram = np.dot(a.T.conj(), a)
    identity = np.eye(ncols)
    assert_allclose(gram, identity, rtol=rtol, atol=atol)
def assert_upper_tri(a, rtol=None, atol=None):
    """Assert that every entry strictly below the main diagonal is ~zero.

    Parameters
    ----------
    a : ndarray
        Two-dimensional array to inspect.
    rtol, atol : float, optional
        Tolerances forwarded to ``assert_allclose``. When omitted they are
        derived from ``np.finfo(a.dtype)``.
    """
    if rtol is None:
        digits = np.finfo(a.dtype).precision
        rtol = 10.0 ** -(digits - 2)
    if atol is None:
        atol = 2 * np.finfo(a.dtype).eps

    nrows, ncols = a.shape[0], a.shape[1]
    # Boolean mask selecting the strictly-lower triangle.
    below_diag = np.tri(nrows, ncols, -1, np.bool_)
    assert_allclose(a[below_diag], 0.0, rtol=rtol, atol=atol)
def check_qr(q, r, a, rtol, atol, assert_sqr=True):
    """Validate that ``q`` and ``r`` form a QR factorisation of ``a``.

    Three properties are asserted, in order: ``q`` has orthonormal columns
    (and is square when ``assert_sqr`` is true), ``r`` is upper triangular,
    and the product ``q @ r`` reproduces ``a`` within the tolerances.
    """
    assert_unitary(q, rtol=rtol, atol=atol, assert_sqr=assert_sqr)
    assert_upper_tri(r, rtol=rtol, atol=atol)
    product = q.dot(r)
    assert_allclose(product, a, rtol=rtol, atol=atol)
def make_strided(arrs):
    """Return copies of ``arrs`` stored as non-contiguous views.

    Each 1-D or 2-D array is copied into a larger zero-filled buffer and
    returned as a stepped slice of it. The ``(step, offset)`` pair for each
    axis is taken from a fixed table that is walked cyclically, one entry
    per axis, across the whole input sequence.

    Raises
    ------
    ValueError
        If an array is neither 1-D nor 2-D.
    """
    step_offset_table = [(3, 7), (2, 2), (3, 4), (4, 2),
                         (5, 4), (2, 3), (2, 1), (4, 5)]
    table_len = len(step_offset_table)
    cursor = 0
    strided = []
    for arr in arrs:
        if arr.ndim not in (1, 2):
            raise ValueError('make_strided only works for ndim = 1 or'
                             ' 2 arrays')
        # One (step, offset) pair per axis, consumed in table order.
        specs = [step_offset_table[(cursor + axis) % table_len]
                 for axis in range(arr.ndim)]
        cursor += arr.ndim
        backing_shape = tuple(step * extent + offset
                              for (step, offset), extent
                              in zip(specs, arr.shape))
        backing = np.zeros(backing_shape, arr.dtype)
        window = backing[tuple(slice(offset, None, step)
                               for step, offset in specs)]
        window[...] = arr
        strided.append(window)
    return strided
def negate_strides(arrs):
    """Return copies of ``arrs`` whose views are reversed along every axis.

    A zero-filled array shaped like each input is created, viewed with a
    step of -1 along each axis, and then filled with the input's values.

    Raises
    ------
    ValueError
        If an array is neither 1-D nor 2-D.
    """
    reversed_views = []
    for arr in arrs:
        buf = np.zeros_like(arr)
        if buf.ndim not in (1, 2):
            raise ValueError('negate_strides only works for ndim = 1 or'
                             ' 2 arrays')
        # Step backwards along every axis of the freshly allocated buffer.
        flipped = buf[(slice(None, None, -1),) * buf.ndim]
        flipped[...] = arr
        reversed_views.append(flipped)
    return reversed_views
def nonitemsize_strides(arrs):
    """Return copies of ``arrs`` whose strides are not multiples of itemsize.

    Each input is copied into field ``'a'`` of a structured array that also
    carries a one-byte ``'junk'`` field; the returned view of field ``'a'``
    therefore steps over that extra byte between consecutive elements.
    """
    views = []
    for arr in arrs:
        elem_dtype = arr.dtype
        record_dtype = [('a', elem_dtype), ('junk', 'S1')]
        padded = np.zeros(arr.shape, record_dtype)
        field_view = padded.getfield(elem_dtype)
        field_view[...] = arr
        views.append(field_view)
    return views
def make_nonnative(arrs):
    """Return copies of ``arrs`` converted to the opposite byte order."""
    swapped = []
    for arr in arrs:
        flipped_dtype = arr.dtype.newbyteorder()
        swapped.append(arr.astype(flipped_dtype))
    return swapped
class BaseQRdeltas:
    """Common fixture code for the QR update tests.

    Subclasses are expected to define a ``dtype`` attribute (a NumPy dtype
    instance); it drives both the tolerances and the generated data.
    """

    def setup_method(self):
        """Derive ``rtol`` and ``atol`` from the precision of ``dtype``."""
        info = np.finfo(self.dtype)
        self.rtol = 10.0 ** -(info.precision - 2)
        self.atol = 10 * info.eps

    def generate(self, type, mode='full'):
        """Build a reproducible random matrix and its QR factors.

        Parameters
        ----------
        type : str
            One of 'sqr', 'tall', 'fat', 'Mx1', '1xN' or '1x1'.
        mode : str, optional
            Passed through to ``scipy.linalg.qr``.

        Returns
        -------
        a, q, r : ndarray
            The matrix (cast to ``dtype``) and the factors of its QR
            decomposition.
        """
        # Reseed on every call so each test sees identical data.
        np.random.seed(29382)
        shapes = {
            'sqr': (8, 8),
            'tall': (12, 7),
            'fat': (7, 12),
            'Mx1': (8, 1),
            '1xN': (1, 8),
            '1x1': (1, 1),
        }
        shape = shapes[type]
        a = np.random.random(shape)
        if np.iscomplexobj(self.dtype.type(1)):
            imag = np.random.random(shape)
            a = a + 1j * imag
        a = a.astype(self.dtype)
        q, r = linalg.qr(a, mode=mode)
        return a, q, r
class BaseQRdelete(BaseQRdeltas):
    """Dtype-independent test cases for ``qr_delete``.

    Subclasses provide a ``dtype`` class attribute. Each test obtains a
    matrix ``a`` and its factors ``q``, ``r`` from ``generate``, removes rows
    or columns with ``qr_delete`` and compares the outcome against
    ``np.delete`` applied to ``a``.
    """

    def _check_single_deletes(self, kind, which, mode='full'):
        """Delete each single row or column in turn and check the result.

        Parameters
        ----------
        kind : str
            Shape key understood by ``generate``.
        which : {'row', 'col'}
            Whether rows or columns are removed.
        mode : {'full', 'economic'}, optional
            Decomposition mode; Q is only required to be square for 'full'.
        """
        a, q, r = self.generate(kind, mode)
        axis = 0 if which == 'row' else 1
        # The bound comes from r, so for an economic decomposition only the
        # first r.shape[0] rows are visited.
        for k in range(r.shape[axis]):
            q1, r1 = qr_delete(q, r, k, which=which, overwrite_qr=False)
            a1 = np.delete(a, k, axis)
            check_qr(q1, r1, a1, self.rtol, self.atol, mode == 'full')

    def _check_block_deletes(self, kind, which, mode='full',
                             ndels=range(2, 6)):
        """Delete contiguous blocks of rows or columns and check the result.

        For every block size in ``ndels`` the block start runs over
        ``range(extent - ndel)``, where ``extent`` is ``a.shape[0]`` for rows
        and ``r.shape[1]`` for columns.
        """
        a, q, r = self.generate(kind, mode)
        if which == 'row':
            axis, extent = 0, a.shape[0]
        else:
            axis, extent = 1, r.shape[1]
        for ndel in ndels:
            for k in range(extent - ndel):
                q1, r1 = qr_delete(q, r, k, ndel, which, overwrite_qr=False)
                a1 = np.delete(a, slice(k, k + ndel), axis)
                check_qr(q1, r1, a1, self.rtol, self.atol, mode == 'full')

    def test_sqr_1_row(self):
        """Single-row deletes from a square matrix."""
        self._check_single_deletes('sqr', 'row')

    def test_sqr_p_row(self):
        """Block row deletes from a square matrix."""
        self._check_block_deletes('sqr', 'row')

    def test_sqr_1_col(self):
        """Single-column deletes from a square matrix."""
        self._check_single_deletes('sqr', 'col')

    def test_sqr_p_col(self):
        """Block column deletes from a square matrix."""
        self._check_block_deletes('sqr', 'col')

    def test_tall_1_row(self):
        """Single-row deletes from a tall matrix."""
        self._check_single_deletes('tall', 'row')

    def test_tall_p_row(self):
        """Block row deletes from a tall matrix."""
        self._check_block_deletes('tall', 'row')

    def test_tall_1_col(self):
        """Single-column deletes from a tall matrix."""
        self._check_single_deletes('tall', 'col')

    def test_tall_p_col(self):
        """Block column deletes from a tall matrix."""
        self._check_block_deletes('tall', 'col')

    def test_fat_1_row(self):
        """Single-row deletes from a fat matrix."""
        self._check_single_deletes('fat', 'row')

    def test_fat_p_row(self):
        """Block row deletes from a fat matrix."""
        self._check_block_deletes('fat', 'row')

    def test_fat_1_col(self):
        """Single-column deletes from a fat matrix."""
        self._check_single_deletes('fat', 'col')

    def test_fat_p_col(self):
        """Block column deletes from a fat matrix."""
        self._check_block_deletes('fat', 'col')

    def test_economic_1_row(self):
        """Single-row deletes from an economic decomposition."""
        # this test always starts and ends with an economic decomp.
        self._check_single_deletes('tall', 'row', 'economic')

    # for economic row deletes
    # eco - prow = eco
    # eco - prow = sqr
    # eco - prow = fat
    def base_economic_p_row_xxx(self, ndel):
        """Delete blocks of ``ndel`` rows from an economic decomposition."""
        self._check_block_deletes('tall', 'row', 'economic', ndels=(ndel,))

    def test_economic_p_row_economic(self):
        """(12, 7) - (3, 7) = (9, 7): the result stays economic."""
        self.base_economic_p_row_xxx(3)

    def test_economic_p_row_sqr(self):
        """(12, 7) - (5, 7) = (7, 7): the result becomes square."""
        self.base_economic_p_row_xxx(5)

    def test_economic_p_row_fat(self):
        """(12, 7) - (7, 7) = (5, 7): the result becomes fat."""
        self.base_economic_p_row_xxx(7)

    def test_economic_1_col(self):
        """Single-column deletes from an economic decomposition."""
        self._check_single_deletes('tall', 'col', 'economic')

    def test_economic_p_col(self):
        """Block column deletes from an economic decomposition."""
        self._check_block_deletes('tall', 'col', 'economic')

    def test_Mx1_1_row(self):
        """Single-row deletes from a column vector."""
        self._check_single_deletes('Mx1', 'row')

    def test_Mx1_p_row(self):
        """Block row deletes from a column vector."""
        self._check_block_deletes('Mx1', 'row')

    def test_1xN_1_col(self):
        """Single-column deletes from a row vector."""
        self._check_single_deletes('1xN', 'col')

    def test_1xN_p_col(self):
        """Block column deletes from a row vector."""
        self._check_block_deletes('1xN', 'col')

    def test_Mx1_economic_1_row(self):
        """Single-row deletes from an economic column-vector decomposition."""
        self._check_single_deletes('Mx1', 'row', 'economic')

    def test_Mx1_economic_p_row(self):
        """Block row deletes from an economic column-vector decomposition."""
        self._check_block_deletes('Mx1', 'row', 'economic')

    def test_delete_last_1_row(self):
        """Removing the only row leaves empty factors."""
        # full and eco are the same for 1xN
        a, q, r = self.generate('1xN')
        q1, r1 = qr_delete(q, r, 0, 1, 'row')
        assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))

    def test_delete_last_p_row(self):
        """Removing every row at once leaves empty factors (full and eco)."""
        for mode in ('full', 'economic'):
            a, q, r = self.generate('tall', mode)
            q1, r1 = qr_delete(q, r, 0, a.shape[0], 'row')
            assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
            assert_equal(r1,
                         np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))

    def test_delete_last_1_col(self):
        """Removing the only column: eco empties Q, full keeps a unitary Q."""
        a, q, r = self.generate('Mx1', 'economic')
        q1, r1 = qr_delete(q, r, 0, 1, 'col')
        assert_equal(q1, np.ndarray(shape=(q.shape[0], 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, 0), dtype=r.dtype))

        a, q, r = self.generate('Mx1', 'full')
        q1, r1 = qr_delete(q, r, 0, 1, 'col')
        assert_unitary(q1)
        assert_(q1.dtype == q.dtype)
        assert_(q1.shape == q.shape)
        assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))

    def test_delete_last_p_col(self):
        """Removing every column: full keeps a unitary Q, eco empties Q."""
        a, q, r = self.generate('tall', 'full')
        q1, r1 = qr_delete(q, r, 0, a.shape[1], 'col')
        assert_unitary(q1)
        assert_(q1.dtype == q.dtype)
        assert_(q1.shape == q.shape)
        assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))

        a, q, r = self.generate('tall', 'economic')
        q1, r1 = qr_delete(q, r, 0, a.shape[1], 'col')
        assert_equal(q1, np.ndarray(shape=(q.shape[0], 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, 0), dtype=r.dtype))

    def test_delete_1x1_row_col(self):
        """Row and column deletes on a 1x1 matrix."""
        a, q, r = self.generate('1x1')
        q1, r1 = qr_delete(q, r, 0, 1, 'row')
        assert_equal(q1, np.ndarray(shape=(0, 0), dtype=q.dtype))
        assert_equal(r1, np.ndarray(shape=(0, r.shape[1]), dtype=r.dtype))

        a, q, r = self.generate('1x1')
        q1, r1 = qr_delete(q, r, 0, 1, 'col')
        assert_unitary(q1)
        assert_(q1.dtype == q.dtype)
        assert_(q1.shape == q.shape)
        assert_equal(r1, np.ndarray(shape=(r.shape[0], 0), dtype=r.dtype))

    # all full qr, row deletes and single column deletes should be able to
    # handle any non negative strides. (only row and column vector
    # operations are used.) p column delete require fortran ordered
    # Q and R and will make a copy as necessary.  Economic qr row deletes
    # require a contiguous q.
    def base_non_simple_strides(self, adjust_strides, ks, p, which,
                                overwriteable):
        """Run ``qr_delete`` on inputs with unusual memory layouts.

        Parameters
        ----------
        adjust_strides : callable
            Maps a sequence of arrays to re-laid-out copies.
        ks : iterable of int
            Start indices of the deleted block.
        p : int
            Number of rows or columns to delete.
        which : {'row', 'col'}
            Whether rows or columns are removed.
        overwriteable : bool
            When true, additionally assert that the inputs passed with
            ``overwrite_qr=True`` hold the returned factors afterwards.
        """
        if which == 'row':
            qind = (slice(p, None), slice(p, None))
            rind = (slice(p, None), slice(None))
        else:
            qind = (slice(None), slice(None))
            rind = (slice(None), slice(None, -p))
        axis = 0 if which == 'row' else 1

        for kind, k in itertools.product(['sqr', 'tall', 'fat'], ks):
            a, q0, r0 = self.generate(kind)
            qs, rs = adjust_strides((q0, r0))
            if p == 1:
                a1 = np.delete(a, k, axis)
            else:
                s = slice(k, k + p)
                if k < 0:
                    s = slice(k, k + p + a.shape[axis])
                a1 = np.delete(a, s, axis)

            # for each variable, q, r we try with it strided and
            # overwrite=False. Then we try with overwrite=True, and make
            # sure that q and r are still overwritten.

            # 1) re-laid-out q, plain Fortran-ordered r
            r = r0.copy('F')
            q1, r1 = qr_delete(qs, r, k, p, which, False)
            check_qr(q1, r1, a1, self.rtol, self.atol)
            q1o, r1o = qr_delete(qs, r, k, p, which, True)
            check_qr(q1o, r1o, a1, self.rtol, self.atol)
            if overwriteable:
                assert_allclose(q1o, qs[qind], rtol=self.rtol, atol=self.atol)
                assert_allclose(r1o, r[rind], rtol=self.rtol, atol=self.atol)

            # 2) plain Fortran-ordered q, re-laid-out r
            q = q0.copy('F')
            q2, r2 = qr_delete(q, rs, k, p, which, False)
            check_qr(q2, r2, a1, self.rtol, self.atol)
            q2o, r2o = qr_delete(q, rs, k, p, which, True)
            check_qr(q2o, r2o, a1, self.rtol, self.atol)
            if overwriteable:
                assert_allclose(q2o, q[qind], rtol=self.rtol, atol=self.atol)
                assert_allclose(r2o, rs[rind], rtol=self.rtol, atol=self.atol)

            # 3) both q and r re-laid-out
            q = q0.copy('F')
            r = r0.copy('F')
            # since some of these were consumed above
            qs, rs = adjust_strides((q, r))
            q3, r3 = qr_delete(qs, rs, k, p, which, False)
            check_qr(q3, r3, a1, self.rtol, self.atol)
            q3o, r3o = qr_delete(qs, rs, k, p, which, True)
            check_qr(q3o, r3o, a1, self.rtol, self.atol)
            if overwriteable:
                # Compare this scenario's own result (q3o), not q2o.
                assert_allclose(q3o, qs[qind], rtol=self.rtol, atol=self.atol)
                assert_allclose(r3o, rs[rind], rtol=self.rtol, atol=self.atol)

    def test_non_unit_strides_1_row(self):
        self.base_non_simple_strides(make_strided, [0], 1, 'row', True)

    def test_non_unit_strides_p_row(self):
        self.base_non_simple_strides(make_strided, [0], 3, 'row', True)

    def test_non_unit_strides_1_col(self):
        self.base_non_simple_strides(make_strided, [0], 1, 'col', True)

    def test_non_unit_strides_p_col(self):
        self.base_non_simple_strides(make_strided, [0], 3, 'col', False)

    def test_neg_strides_1_row(self):
        self.base_non_simple_strides(negate_strides, [0], 1, 'row', False)

    def test_neg_strides_p_row(self):
        self.base_non_simple_strides(negate_strides, [0], 3, 'row', False)

    def test_neg_strides_1_col(self):
        self.base_non_simple_strides(negate_strides, [0], 1, 'col', False)

    def test_neg_strides_p_col(self):
        self.base_non_simple_strides(negate_strides, [0], 3, 'col', False)

    def test_non_itemize_strides_1_row(self):
        self.base_non_simple_strides(nonitemsize_strides, [0], 1, 'row', False)

    def test_non_itemize_strides_p_row(self):
        self.base_non_simple_strides(nonitemsize_strides, [0], 3, 'row', False)

    def test_non_itemize_strides_1_col(self):
        self.base_non_simple_strides(nonitemsize_strides, [0], 1, 'col', False)

    def test_non_itemize_strides_p_col(self):
        self.base_non_simple_strides(nonitemsize_strides, [0], 3, 'col', False)

    def test_non_native_byte_order_1_row(self):
        self.base_non_simple_strides(make_nonnative, [0], 1, 'row', False)

    def test_non_native_byte_order_p_row(self):
        self.base_non_simple_strides(make_nonnative, [0], 3, 'row', False)

    def test_non_native_byte_order_1_col(self):
        self.base_non_simple_strides(make_nonnative, [0], 1, 'col', False)

    def test_non_native_byte_order_p_col(self):
        self.base_non_simple_strides(make_nonnative, [0], 3, 'col', False)

    def test_neg_k(self):
        """Negative ``k`` counts from the end of the affected axis."""
        a, q, r = self.generate('sqr')
        for k, p, w in itertools.product([-3, -7], [1, 3], ['row', 'col']):
            q1, r1 = qr_delete(q, r, k, p, w, overwrite_qr=False)
            axis = 0 if w == 'row' else 1
            # Translate the negative index using the length of the axis
            # that is actually being modified.
            start = k + a.shape[axis]
            a1 = np.delete(a, slice(start, start + p), axis)
            check_qr(q1, r1, a1, self.rtol, self.atol)

    def base_overwrite_qr(self, which, p, test_C, test_F, mode='full'):
        """Check ``overwrite_qr`` handling for a delete starting at index 3.

        With ``overwrite_qr=False`` the inputs must still factor the original
        matrix afterwards. With ``overwrite_qr=True`` the relevant slices of
        the inputs must match the returned factors; this is exercised for
        Fortran-ordered inputs when ``test_F`` is true and for C-ordered
        inputs when ``test_C`` is true.
        """
        assert_sqr = mode == 'full'
        if which == 'row':
            qind = (slice(p, None), slice(p, None))
            rind = (slice(p, None), slice(None))
        else:
            qind = (slice(None), slice(None))
            rind = (slice(None), slice(None, -p))
        axis = 0 if which == 'row' else 1

        a, q0, r0 = self.generate('sqr', mode)
        if p == 1:
            a1 = np.delete(a, 3, axis)
        else:
            a1 = np.delete(a, slice(3, 3 + p), axis)

        # don't overwrite
        q = q0.copy('F')
        r = r0.copy('F')
        q1, r1 = qr_delete(q, r, 3, p, which, False)
        check_qr(q1, r1, a1, self.rtol, self.atol, assert_sqr)
        check_qr(q, r, a, self.rtol, self.atol, assert_sqr)

        if test_F:
            q = q0.copy('F')
            r = r0.copy('F')
            q2, r2 = qr_delete(q, r, 3, p, which, True)
            check_qr(q2, r2, a1, self.rtol, self.atol, assert_sqr)
            # verify the overwriting
            assert_allclose(q2, q[qind], rtol=self.rtol, atol=self.atol)
            assert_allclose(r2, r[rind], rtol=self.rtol, atol=self.atol)

        if test_C:
            q = q0.copy('C')
            r = r0.copy('C')
            q3, r3 = qr_delete(q, r, 3, p, which, True)
            check_qr(q3, r3, a1, self.rtol, self.atol, assert_sqr)
            assert_allclose(q3, q[qind], rtol=self.rtol, atol=self.atol)
            assert_allclose(r3, r[rind], rtol=self.rtol, atol=self.atol)

    def test_overwrite_qr_1_row(self):
        # any positively strided q and r.
        self.base_overwrite_qr('row', 1, True, True)

    def test_overwrite_economic_qr_1_row(self):
        # Any contiguous q and positively strided r.
        self.base_overwrite_qr('row', 1, True, True, 'economic')

    def test_overwrite_qr_1_col(self):
        # any positively strided q and r.
        # full and eco share code paths
        self.base_overwrite_qr('col', 1, True, True)

    def test_overwrite_qr_p_row(self):
        # any positively strided q and r.
        self.base_overwrite_qr('row', 3, True, True)

    def test_overwrite_economic_qr_p_row(self):
        # any contiguous q and positively strided r
        self.base_overwrite_qr('row', 3, True, True, 'economic')

    def test_overwrite_qr_p_col(self):
        # only F ordered q and r can be overwritten for cols
        # full and eco share code paths
        self.base_overwrite_qr('col', 3, False, True)

    def test_bad_which(self):
        """An unknown ``which`` value is rejected."""
        a, q, r = self.generate('sqr')
        assert_raises(ValueError, qr_delete, q, r, 0, which='foo')

    def test_bad_k(self):
        """Out-of-range ``k`` values are rejected."""
        a, q, r = self.generate('tall')
        assert_raises(ValueError, qr_delete, q, r, q.shape[0], 1)
        assert_raises(ValueError, qr_delete, q, r, -q.shape[0]-1, 1)
        assert_raises(ValueError, qr_delete, q, r, r.shape[0], 1, 'col')
        assert_raises(ValueError, qr_delete, q, r, -r.shape[0]-1, 1, 'col')

    def test_bad_p(self):
        """Invalid block sizes are rejected."""
        a, q, r = self.generate('tall')
        # p must be positive
        assert_raises(ValueError, qr_delete, q, r, 0, -1)
        assert_raises(ValueError, qr_delete, q, r, 0, -1, 'col')

        # and nonzero
        assert_raises(ValueError, qr_delete, q, r, 0, 0)
        assert_raises(ValueError, qr_delete, q, r, 0, 0, 'col')

        # must have at least k+p rows or cols, depending.
        assert_raises(ValueError, qr_delete, q, r, 3, q.shape[0]-2)
        assert_raises(ValueError, qr_delete, q, r, 3, r.shape[1]-2, 'col')

    def test_empty_q(self):
        """An empty Q is rejected."""
        a, q, r = self.generate('tall')
        # same code path for 'row' and 'col'
        assert_raises(ValueError, qr_delete, np.array([]), r, 0, 1)

    def test_empty_r(self):
        """An empty R is rejected."""
        a, q, r = self.generate('tall')
        # same code path for 'row' and 'col'
        assert_raises(ValueError, qr_delete, q, np.array([]), 0, 1)

    def test_mismatched_q_and_r(self):
        """Q and R with incompatible shapes are rejected."""
        a, q, r = self.generate('tall')
        r = r[1:]
        assert_raises(ValueError, qr_delete, q, r, 0, 1)

    def test_unsupported_dtypes(self):
        """Inputs of dtypes outside the supported set are rejected."""
        # 'clongdouble' is spelled out because the 'longcomplex' alias is
        # not available in NumPy 2.
        dts = ['int8', 'int16', 'int32', 'int64',
               'uint8', 'uint16', 'uint32', 'uint64',
               'float16', 'longdouble', 'clongdouble',
               'bool']
        a, q0, r0 = self.generate('tall')
        for dtype in dts:
            q = q0.real.astype(dtype)
            with np.errstate(invalid="ignore"):
                r = r0.real.astype(dtype)
            assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'row')
            assert_raises(ValueError, qr_delete, q, r0, 0, 2, 'row')
            assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'col')
            assert_raises(ValueError, qr_delete, q, r0, 0, 2, 'col')

            assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'row')
            assert_raises(ValueError, qr_delete, q0, r, 0, 2, 'row')
            assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'col')
            assert_raises(ValueError, qr_delete, q0, r, 0, 2, 'col')

    def test_check_finite(self):
        """A NaN in either Q or R is rejected."""
        a0, q0, r0 = self.generate('tall')

        q = q0.copy('F')
        q[1, 1] = np.nan
        assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'row')
        assert_raises(ValueError, qr_delete, q, r0, 0, 3, 'row')
        assert_raises(ValueError, qr_delete, q, r0, 0, 1, 'col')
        assert_raises(ValueError, qr_delete, q, r0, 0, 3, 'col')

        r = r0.copy('F')
        r[1, 1] = np.nan
        assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'row')
        assert_raises(ValueError, qr_delete, q0, r, 0, 3, 'row')
        assert_raises(ValueError, qr_delete, q0, r, 0, 1, 'col')
        assert_raises(ValueError, qr_delete, q0, r, 0, 3, 'col')

    def test_qr_scalar(self):
        """Scalar (0-d) Q or R is rejected."""
        a, q, r = self.generate('1x1')
        assert_raises(ValueError, qr_delete, q[0, 0], r, 0, 1, 'row')
        assert_raises(ValueError, qr_delete, q, r[0, 0], 0, 1, 'row')
        assert_raises(ValueError, qr_delete, q[0, 0], r, 0, 1, 'col')
        assert_raises(ValueError, qr_delete, q, r[0, 0], 0, 1, 'col')
class TestQRdelete_f(BaseQRdelete):
    """Run the ``qr_delete`` test cases with dtype code ``'f'``."""

    dtype = np.dtype('f')
class TestQRdelete_F(BaseQRdelete):
    """Run the ``qr_delete`` test cases with dtype code ``'F'``."""

    dtype = np.dtype('F')
class TestQRdelete_d(BaseQRdelete):
    """Run the ``qr_delete`` test cases with dtype code ``'d'``."""

    dtype = np.dtype('d')
class TestQRdelete_D(BaseQRdelete):
    """Run the ``qr_delete`` test cases with dtype code ``'D'``."""

    dtype = np.dtype('D')
- class BaseQRinsert(BaseQRdeltas):
def generate(self, type, mode='full', which='row', p=1):
    """Build a matrix, its QR factors and a block ``u`` to insert.

    Parameters
    ----------
    type, mode
        Forwarded to the parent ``generate``.
    which : {'row', 'col'}, optional
        Whether ``u`` holds rows or columns.
    p : int, optional
        Number of rows/columns in ``u``; must be positive. For ``p == 1``
        ``u`` is one-dimensional.

    Returns
    -------
    a, q, r, u : ndarray
        ``a``, ``q`` and ``r`` come from the parent ``generate``; ``u`` is
        random data cast to ``self.dtype``.

    Raises
    ------
    ValueError
        If ``which`` is neither 'row' nor 'col'.
    """
    a, q, r = super().generate(type, mode)
    assert_(p > 0)

    # The super call seeded the RNG, so u is reproducible as well.
    if which == 'row':
        u_shape = a.shape[1] if p == 1 else (p, a.shape[1])
    elif which == 'col':
        u_shape = a.shape[0] if p == 1 else (a.shape[0], p)
    else:
        # Previously the exception was constructed but never raised.
        raise ValueError('which should be either "row" or "col"')
    u = np.random.random(u_shape)

    if np.iscomplexobj(self.dtype.type(1)):
        b = np.random.random(u.shape)
        u = u + 1j * b

    u = u.astype(self.dtype)
    return a, q, r, u
def test_sqr_1_row(self):
    """Insert one row at every position of a square matrix."""
    a, q, r, u = self.generate('sqr', which='row')
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos)
        expected = np.insert(a, pos, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_sqr_p_row(self):
    """Insert a 3-row block at every position of a square matrix.

    Adding rows to a square matrix always yields a tall one.
    """
    a, q, r, u = self.generate('sqr', which='row', p=3)
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_sqr_1_col(self):
    """Insert one column at every position of a square matrix."""
    a, q, r, u = self.generate('sqr', which='col')
    n_slots = r.shape[1] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, which='col',
                                 overwrite_qru=False)
        expected = np.insert(a, pos, u, axis=1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_sqr_p_col(self):
    """Insert a 3-column block at every position of a square matrix.

    Adding columns to a square matrix always yields a fat one.
    """
    a, q, r, u = self.generate('sqr', which='col', p=3)
    n_slots = r.shape[1] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, which='col',
                                 overwrite_qru=False)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, axis=1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_tall_1_row(self):
    """Insert one row at every position of a tall matrix."""
    a, q, r, u = self.generate('tall', which='row')
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos)
        expected = np.insert(a, pos, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_tall_p_row(self):
    """Insert a 3-row block at every position of a tall matrix.

    Adding rows to a tall matrix always leaves it tall.
    """
    a, q, r, u = self.generate('tall', which='row', p=3)
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_tall_1_col(self):
    """Insert one column at every position of a tall matrix."""
    a, q, r, u = self.generate('tall', which='col')
    n_slots = r.shape[1] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, which='col',
                                 overwrite_qru=False)
        expected = np.insert(a, pos, u, axis=1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
- # for column adds to tall matrices there are three cases to test
- # tall + pcol --> tall
- # tall + pcol --> sqr
- # tall + pcol --> fat
def base_tall_p_col_xxx(self, p):
    """Insert a ``p``-column block at every position of a tall matrix."""
    a, q, r, u = self.generate('tall', which='col', p=p)
    n_slots = r.shape[1] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, which='col',
                                 overwrite_qru=False)
        where = np.full(p, pos, np.intp)
        expected = np.insert(a, where, u, axis=1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_tall_p_col_tall(self):
    """12x7 plus a 12x3 block gives 12x10: the matrix stays tall."""
    self.base_tall_p_col_xxx(p=3)
def test_tall_p_col_sqr(self):
    """12x7 plus a 12x5 block gives 12x12: the matrix becomes square."""
    self.base_tall_p_col_xxx(p=5)
def test_tall_p_col_fat(self):
    """12x7 plus a 12x7 block gives 12x14: the matrix becomes fat."""
    self.base_tall_p_col_xxx(p=7)
def test_fat_1_row(self):
    """Insert one row at every position of a fat matrix."""
    a, q, r, u = self.generate('fat', which='row')
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos)
        expected = np.insert(a, pos, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
- # for row adds to fat matrices there are three cases to test
- # fat + prow --> fat
- # fat + prow --> sqr
- # fat + prow --> tall
def base_fat_p_row_xxx(self, p):
    """Insert a ``p``-row block at every position of a fat matrix."""
    a, q, r, u = self.generate('fat', which='row', p=p)
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos)
        where = np.full(p, pos, np.intp)
        expected = np.insert(a, where, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_fat_p_row_fat(self):
    """7x12 plus a 3x12 block gives 10x12: the matrix stays fat."""
    self.base_fat_p_row_xxx(p=3)
def test_fat_p_row_sqr(self):
    """7x12 plus a 5x12 block gives 12x12: the matrix becomes square."""
    self.base_fat_p_row_xxx(p=5)
def test_fat_p_row_tall(self):
    """7x12 plus a 7x12 block gives 14x12: the matrix becomes tall."""
    self.base_fat_p_row_xxx(p=7)
def test_fat_1_col(self):
    """Insert one column at every position of a fat matrix."""
    a, q, r, u = self.generate('fat', which='col')
    n_slots = r.shape[1] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, which='col',
                                 overwrite_qru=False)
        expected = np.insert(a, pos, u, axis=1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_fat_p_col(self):
    """Insert a 3-column block at every position of a fat matrix.

    Adding columns to a fat matrix always leaves it fat.
    """
    a, q, r, u = self.generate('fat', which='col', p=3)
    n_slots = r.shape[1] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, which='col',
                                 overwrite_qru=False)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, axis=1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_economic_1_row(self):
    """Insert one row into an economic decomposition of a tall matrix."""
    a, q, r, u = self.generate('tall', 'economic', 'row')
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, overwrite_qru=False)
        expected = np.insert(a, pos, u, axis=0)
        # Economic Q is not square, so skip the squareness assertion.
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_economic_p_row(self):
    """Insert a 3-row block into an economic decomposition.

    Adding rows to a tall matrix always leaves it tall.
    """
    a, q, r, u = self.generate('tall', 'economic', 'row', 3)
    n_slots = r.shape[0] + 1
    for pos in range(n_slots):
        q_new, r_new = qr_insert(q, r, u, pos, overwrite_qru=False)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, axis=0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_economic_1_col(self):
    """Column update of a tall economic decomposition at every index.

    A fresh copy of ``u`` is handed to qr_insert on each iteration; the
    expected matrix is built from the original ``u``.
    """
    a, q, r, u = self.generate('tall', 'economic', which='col')
    n_positions = r.shape[1] + 1
    for pos in range(n_positions):
        u_arg = u.copy()
        q_new, r_new = qr_insert(q, r, u_arg, pos, 'col',
                                 overwrite_qru=False)
        expected = np.insert(a, pos, u, 1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_economic_1_col_bad_update(self):
    """A column lying in span(Q) must make qr_insert raise LinAlgError."""
    # When the column to be added lies in the span of Q, the update is
    # not meaningful. This is detected, and a LinAlgError is issued.
    dt = self.dtype
    q = np.eye(5, 3, dtype=dt)
    r = np.eye(3, dtype=dt)
    # u equals the first column of q, so it is exactly in span(q).
    u = np.array([1, 0, 0, 0, 0], dt)
    assert_raises(linalg.LinAlgError, qr_insert, q, r, u, 0, 'col')
# Adding a block of p columns to an economic decomposition has three
# outcomes to test:
# eco + pcol --> eco
# eco + pcol --> sqr
# eco + pcol --> fat
def base_economic_p_col_xxx(self, p):
    """Shared driver: insert ``p`` columns into a tall economic case.

    The block is inserted at every index in ``range(r.shape[1] + 1)`` and
    compared with ``np.insert`` applied to the matrix itself.
    """
    a, q, r, u = self.generate('tall', 'economic', which='col', p=p)
    n_positions = r.shape[1] + 1
    for pos in range(n_positions):
        q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
        where = np.full(p, pos, np.intp)
        expected = np.insert(a, where, u, 1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_economic_p_col_eco(self):
    """Three added columns: the decomposition stays economic."""
    # 12x7 + 12x3 = 12x10 --> stays eco
    self.base_economic_p_col_xxx(p=3)
def test_economic_p_col_sqr(self):
    """Five added columns: the matrix becomes square."""
    # 12x7 + 12x5 = 12x12 --> becomes sqr
    self.base_economic_p_col_xxx(p=5)
def test_economic_p_col_fat(self):
    """Seven added columns: the matrix becomes fat."""
    # 12x7 + 12x7 = 12x14 --> becomes fat
    self.base_economic_p_col_xxx(p=7)
def test_Mx1_1_row(self):
    """Row update of an 'Mx1' case, tried at every insertion index."""
    a, q, r, u = self.generate('Mx1', which='row')
    for pos in range(1 + r.shape[0]):
        q_new, r_new = qr_insert(q, r, u, pos)
        expected = np.insert(a, pos, u, 0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_Mx1_p_row(self):
    """Three-row block update of an 'Mx1' case at every insertion index."""
    a, q, r, u = self.generate('Mx1', which='row', p=3)
    for pos in range(1 + r.shape[0]):
        q_new, r_new = qr_insert(q, r, u, pos)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, 0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_Mx1_1_col(self):
    """Column update of an 'Mx1' case, tried at every insertion index."""
    a, q, r, u = self.generate('Mx1', which='col')
    for pos in range(1 + r.shape[1]):
        q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
        expected = np.insert(a, pos, u, 1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_Mx1_p_col(self):
    """Three-column block update of an 'Mx1' case at every index."""
    a, q, r, u = self.generate('Mx1', which='col', p=3)
    for pos in range(1 + r.shape[1]):
        q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, 1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol)
def test_Mx1_economic_1_row(self):
    """Row update of an economic 'Mx1' case at every insertion index."""
    a, q, r, u = self.generate('Mx1', 'economic', 'row')
    for pos in range(1 + r.shape[0]):
        q_new, r_new = qr_insert(q, r, u, pos)
        expected = np.insert(a, pos, u, 0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_Mx1_economic_p_row(self):
    """Three-row block update of an economic 'Mx1' case at every index."""
    a, q, r, u = self.generate('Mx1', 'economic', 'row', 3)
    for pos in range(1 + r.shape[0]):
        q_new, r_new = qr_insert(q, r, u, pos)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, 0)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_Mx1_economic_1_col(self):
    """Column update of an economic 'Mx1' case at every insertion index."""
    a, q, r, u = self.generate('Mx1', 'economic', 'col')
    for pos in range(1 + r.shape[1]):
        q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
        expected = np.insert(a, pos, u, 1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_Mx1_economic_p_col(self):
    """Three-column block update of an economic 'Mx1' case at every index."""
    a, q, r, u = self.generate('Mx1', 'economic', 'col', 3)
    for pos in range(1 + r.shape[1]):
        q_new, r_new = qr_insert(q, r, u, pos, 'col', overwrite_qru=False)
        where = np.full(3, pos, np.intp)
        expected = np.insert(a, where, u, 1)
        check_qr(q_new, r_new, expected, self.rtol, self.atol, False)
def test_1xN_1_row(self):
    """Row update of a '1xN' case, tried at every insertion index."""
    a, q, r, u = self.generate('1xN', which='row')
    stop = r.shape[0] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx)
        reference = np.insert(a, idx, u, 0)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1xN_p_row(self):
    """Three-row block update of a '1xN' case at every insertion index."""
    a, q, r, u = self.generate('1xN', which='row', p=3)
    stop = r.shape[0] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx)
        indices = np.full(3, idx, np.intp)
        reference = np.insert(a, indices, u, 0)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1xN_1_col(self):
    """Column update of a '1xN' case, tried at every insertion index."""
    a, q, r, u = self.generate('1xN', which='col')
    stop = r.shape[1] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx, 'col', overwrite_qru=False)
        reference = np.insert(a, idx, u, 1)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1xN_p_col(self):
    """Three-column block update of a '1xN' case at every insertion index."""
    a, q, r, u = self.generate('1xN', which='col', p=3)
    stop = r.shape[1] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx, 'col', overwrite_qru=False)
        indices = np.full(3, idx, np.intp)
        reference = np.insert(a, indices, u, 1)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1x1_1_row(self):
    """Row update of a '1x1' case, tried at every insertion index."""
    a, q, r, u = self.generate('1x1', which='row')
    stop = r.shape[0] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx)
        reference = np.insert(a, idx, u, 0)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1x1_p_row(self):
    """Three-row block update of a '1x1' case at every insertion index."""
    a, q, r, u = self.generate('1x1', which='row', p=3)
    stop = r.shape[0] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx)
        indices = np.full(3, idx, np.intp)
        reference = np.insert(a, indices, u, 0)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1x1_1_col(self):
    """Column update of a '1x1' case, tried at every insertion index."""
    a, q, r, u = self.generate('1x1', which='col')
    stop = r.shape[1] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx, 'col', overwrite_qru=False)
        reference = np.insert(a, idx, u, 1)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1x1_p_col(self):
    """Three-column block update of a '1x1' case at every insertion index."""
    a, q, r, u = self.generate('1x1', which='col', p=3)
    stop = r.shape[1] + 1
    for idx in range(stop):
        q_out, r_out = qr_insert(q, r, u, idx, 'col', overwrite_qru=False)
        indices = np.full(3, idx, np.intp)
        reference = np.insert(a, indices, u, 1)
        check_qr(q_out, r_out, reference, self.rtol, self.atol)
def test_1x1_1_scalar(self):
    """Scalars (0-d elements) in place of q, r or u must raise ValueError.

    Each of the three arguments is replaced in turn by one of its own
    elements, first for row insertion and then for column insertion.
    """
    a, q, r, u = self.generate('1x1', which='row')
    for which in ('row', 'col'):
        assert_raises(ValueError, qr_insert, q[0, 0], r, u, 0, which)
        assert_raises(ValueError, qr_insert, q, r[0, 0], u, 0, which)
        assert_raises(ValueError, qr_insert, q, r, u[0], 0, which)
def base_non_simple_strides(self, adjust_strides, k, p, which):
    """Shared driver: run qr_insert on inputs with unusual memory layouts.

    Parameters
    ----------
    adjust_strides : callable
        Takes a tuple of arrays and returns the same number of arrays; the
        returned arrays are used as the "non-simple" versions of q, r, u.
    k : int
        Index at which the update is inserted.
    p : int
        Number of rows/columns inserted; ``p == 1`` uses a scalar index
        for the reference ``np.insert`` call, otherwise ``p`` equal indices.
    which : {'row', 'col'}
        Passed through to ``self.generate`` and ``qr_insert``.

    For each of the 'sqr', 'tall' and 'fat' cases the adjusted q, r and u
    are substituted one at a time, and finally all together, each time
    with ``overwrite_qru`` False and then True.
    """
    # np.insert axis and index do not depend on the matrix shape.
    axis = 0 if which == 'row' else 1
    index = k if p == 1 else np.full(p, k, np.intp)

    for shape_kind in ['sqr', 'tall', 'fat']:
        a, q0, r0, u0 = self.generate(shape_kind, which=which, p=p)
        qs, rs, us = adjust_strides((q0, r0, u0))
        ai = np.insert(a, index, u0, axis)

        # for each variable, q, r, u we try with it strided and
        # overwrite=False. Then we try with overwrite=True. Nothing
        # is checked to see if it can be overwritten, since only
        # F ordered Q can be overwritten when adding columns.

        # adjusted q only
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        q1, r1 = qr_insert(qs, r, u, k, which, overwrite_qru=False)
        check_qr(q1, r1, ai, self.rtol, self.atol)
        q1o, r1o = qr_insert(qs, r, u, k, which, overwrite_qru=True)
        check_qr(q1o, r1o, ai, self.rtol, self.atol)

        # adjusted r only
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        q2, r2 = qr_insert(q, rs, u, k, which, overwrite_qru=False)
        check_qr(q2, r2, ai, self.rtol, self.atol)
        q2o, r2o = qr_insert(q, rs, u, k, which, overwrite_qru=True)
        check_qr(q2o, r2o, ai, self.rtol, self.atol)

        # adjusted u only
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        q3, r3 = qr_insert(q, r, us, k, which, overwrite_qru=False)
        check_qr(q3, r3, ai, self.rtol, self.atol)
        q3o, r3o = qr_insert(q, r, us, k, which, overwrite_qru=True)
        check_qr(q3o, r3o, ai, self.rtol, self.atol)

        # all three adjusted
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        # since some of these were consumed above
        qs, rs, us = adjust_strides((q, r, u))
        q5, r5 = qr_insert(qs, rs, us, k, which, overwrite_qru=False)
        check_qr(q5, r5, ai, self.rtol, self.atol)
        q5o, r5o = qr_insert(qs, rs, us, k, which, overwrite_qru=True)
        check_qr(q5o, r5o, ai, self.rtol, self.atol)
def test_non_unit_strides_1_row(self):
    """Single-row insert at index 0, inputs adjusted by make_strided."""
    self.base_non_simple_strides(adjust_strides=make_strided, k=0, p=1,
                                 which='row')
def test_non_unit_strides_p_row(self):
    """Three-row insert at index 0, inputs adjusted by make_strided."""
    self.base_non_simple_strides(adjust_strides=make_strided, k=0, p=3,
                                 which='row')
def test_non_unit_strides_1_col(self):
    """Single-column insert at index 0, inputs adjusted by make_strided."""
    self.base_non_simple_strides(adjust_strides=make_strided, k=0, p=1,
                                 which='col')
def test_non_unit_strides_p_col(self):
    """Three-column insert at index 0, inputs adjusted by make_strided."""
    self.base_non_simple_strides(adjust_strides=make_strided, k=0, p=3,
                                 which='col')
def test_neg_strides_1_row(self):
    """Single-row insert at index 0, inputs adjusted by negate_strides."""
    self.base_non_simple_strides(adjust_strides=negate_strides, k=0, p=1,
                                 which='row')
def test_neg_strides_p_row(self):
    """Three-row insert at index 0, inputs adjusted by negate_strides."""
    self.base_non_simple_strides(adjust_strides=negate_strides, k=0, p=3,
                                 which='row')
def test_neg_strides_1_col(self):
    """Single-column insert at index 0, inputs adjusted by negate_strides."""
    self.base_non_simple_strides(adjust_strides=negate_strides, k=0, p=1,
                                 which='col')
def test_neg_strides_p_col(self):
    """Three-column insert at index 0, inputs adjusted by negate_strides."""
    self.base_non_simple_strides(adjust_strides=negate_strides, k=0, p=3,
                                 which='col')
def test_non_itemsize_strides_1_row(self):
    """Single-row insert at index 0, inputs from nonitemsize_strides."""
    self.base_non_simple_strides(adjust_strides=nonitemsize_strides, k=0,
                                 p=1, which='row')
def test_non_itemsize_strides_p_row(self):
    """Three-row insert at index 0, inputs from nonitemsize_strides."""
    self.base_non_simple_strides(adjust_strides=nonitemsize_strides, k=0,
                                 p=3, which='row')
def test_non_itemsize_strides_1_col(self):
    """Single-column insert at index 0, inputs from nonitemsize_strides."""
    self.base_non_simple_strides(adjust_strides=nonitemsize_strides, k=0,
                                 p=1, which='col')
def test_non_itemsize_strides_p_col(self):
    """Three-column insert at index 0, inputs from nonitemsize_strides."""
    self.base_non_simple_strides(adjust_strides=nonitemsize_strides, k=0,
                                 p=3, which='col')
def test_non_native_byte_order_1_row(self):
    """Single-row insert at index 0, inputs adjusted by make_nonnative."""
    self.base_non_simple_strides(adjust_strides=make_nonnative, k=0, p=1,
                                 which='row')
def test_non_native_byte_order_p_row(self):
    """Three-row insert at index 0, inputs adjusted by make_nonnative."""
    self.base_non_simple_strides(adjust_strides=make_nonnative, k=0, p=3,
                                 which='row')
def test_non_native_byte_order_1_col(self):
    """Single-column insert at index 0, inputs adjusted by make_nonnative."""
    self.base_non_simple_strides(adjust_strides=make_nonnative, k=0, p=1,
                                 which='col')
def test_non_native_byte_order_p_col(self):
    """Three-column insert at index 0, inputs adjusted by make_nonnative."""
    self.base_non_simple_strides(adjust_strides=make_nonnative, k=0, p=3,
                                 which='col')
def test_overwrite_qu_rank_1(self):
    """Check what a single-column insert overwrites, for C and F ordered Q.

    Statement order matters here: the non-overwriting call must run before
    the overwriting one because the latter may modify its inputs.
    """
    # when inserting rows, the size of both Q and R change, so only
    # column inserts can overwrite q. Only complex column inserts
    # with C ordered Q overwrite u. Any contiguous Q is overwritten
    # when inserting 1 column
    a, q_base, r, u = self.generate('sqr', which='col', p=1)
    q_c = q_base.copy('C')
    u_saved = u.copy()

    # overwrite_qru=False: the inputs must still factor the original a
    q_keep, r_keep = qr_insert(q_c, r, u, 0, 'col', overwrite_qru=False)
    expected = np.insert(a, 0, u_saved, 1)
    check_qr(q_keep, r_keep, expected, self.rtol, self.atol)
    check_qr(q_c, r, a, self.rtol, self.atol)

    # overwrite_qru=True: the result is still a valid factorization ...
    q_over, r_over = qr_insert(q_c, r, u, 0, 'col', overwrite_qru=True)
    check_qr(q_over, r_over, expected, self.rtol, self.atol)
    # ... and q / u now hold the overwritten values
    assert_allclose(q_over, q_c, rtol=self.rtol, atol=self.atol)
    assert_allclose(u, u_saved.conj(), self.rtol, self.atol)

    # repeat with a Fortran ordered Q and a fresh copy of u
    q_f = q_base.copy('F')
    u_fresh = u_saved.copy()
    q_keep_f, r_keep_f = qr_insert(q_f, r, u_fresh, 0, 'col',
                                   overwrite_qru=False)
    check_qr(q_keep_f, r_keep_f, expected, self.rtol, self.atol)
    check_qr(q_f, r, a, self.rtol, self.atol)

    # overwriting call on the Fortran ordered Q
    q_over_f, r_over_f = qr_insert(q_f, r, u_fresh, 0, 'col',
                                   overwrite_qru=True)
    check_qr(q_over_f, r_over_f, expected, self.rtol, self.atol)
    assert_allclose(q_over_f, q_f, rtol=self.rtol, atol=self.atol)
def test_overwrite_qu_rank_p(self):
    """Check that a three-column insert overwrites an F ordered Q.

    The non-overwriting call runs first so that the inputs can be verified
    to be untouched before the overwriting call is made.
    """
    # when inserting rows, the size of both Q and R change, so only
    # column inserts can potentially overwrite Q. In practice, only
    # F ordered Q are overwritten with a rank p update.
    a, q_base, r, u = self.generate('sqr', which='col', p=3)
    q_f = q_base.copy('F')
    # all three columns go in front of column 0
    expected = np.insert(a, np.zeros(3, np.intp), u, 1)

    # overwrite_qru=False: inputs still factor the original a
    q_keep, r_keep = qr_insert(q_f, r, u, 0, 'col', overwrite_qru=False)
    check_qr(q_keep, r_keep, expected, self.rtol, self.atol)
    check_qr(q_f, r, a, self.rtol, self.atol)

    # overwrite_qru=True: q_f now holds the returned Q
    q_over, r_over = qr_insert(q_f, r, u, 0, 'col', overwrite_qru=True)
    check_qr(q_over, r_over, expected, self.rtol, self.atol)
    assert_allclose(q_over, q_f, rtol=self.rtol, atol=self.atol)
def test_empty_inputs(self):
    """An empty array in place of q, r or u must raise ValueError.

    Checked for row insertion first, then column insertion.
    """
    a, q, r, u = self.generate('sqr', which='row')
    for which in ('row', 'col'):
        assert_raises(ValueError, qr_insert, np.array([]), r, u, 0, which)
        assert_raises(ValueError, qr_insert, q, np.array([]), u, 0, which)
        assert_raises(ValueError, qr_insert, q, r, np.array([]), 0, which)
def test_mismatched_shapes(self):
    """Truncated r, q or u must raise ValueError.

    One argument at a time is sliced to a shorter length; checked for row
    insertion first, then column insertion.
    """
    a, q, r, u = self.generate('tall', which='row')
    for which in ('row', 'col'):
        assert_raises(ValueError, qr_insert, q, r[1:], u, 0, which)
        assert_raises(ValueError, qr_insert, q[:-2], r, u, 0, which)
        assert_raises(ValueError, qr_insert, q, r, u[1:], 0, which)
def test_unsupported_dtypes(self):
    """qr_insert must raise ValueError when q, r or u has an unsupported dtype.

    For each dtype in the list, one argument at a time is replaced by the
    real part of the generated array cast to that dtype, for both row and
    column insertion.
    """
    # 'clongdouble' is the portable name of the extended-precision complex
    # type; the older 'longcomplex' alias is not accepted by NumPy >= 2.0.
    dts = ['int8', 'int16', 'int32', 'int64',
           'uint8', 'uint16', 'uint32', 'uint64',
           'float16', 'longdouble', 'clongdouble',
           'bool']
    a, q0, r0, u0 = self.generate('sqr', which='row')
    for dtype in dts:
        q = q0.real.astype(dtype)
        # casting r can trigger "invalid value" floating point warnings;
        # only the dtype matters here, so silence them.
        with np.errstate(invalid="ignore"):
            r = r0.real.astype(dtype)
        u = u0.real.astype(dtype)
        assert_raises(ValueError, qr_insert, q, r0, u0, 0, 'row')
        assert_raises(ValueError, qr_insert, q, r0, u0, 0, 'col')
        assert_raises(ValueError, qr_insert, q0, r, u0, 0, 'row')
        assert_raises(ValueError, qr_insert, q0, r, u0, 0, 'col')
        assert_raises(ValueError, qr_insert, q0, r0, u, 0, 'row')
        assert_raises(ValueError, qr_insert, q0, r0, u, 0, 'col')
def test_check_finite(self):
    """A NaN in q, r or u must make qr_insert raise ValueError.

    Each argument is poisoned in turn while the other two stay clean; the
    update is passed both as its first column and as the full 2-D block,
    for row insertion and then column insertion.
    """
    a0, q0, r0, u0 = self.generate('sqr', which='row', p=3)

    def expect_value_error(q_arg, r_arg, u_arg):
        # order: row/vector, row/block, col/vector, col/block
        for which in ('row', 'col'):
            for update in (u_arg[:, 0], u_arg):
                assert_raises(ValueError, qr_insert, q_arg, r_arg, update,
                              0, which)

    bad_q = q0.copy('F')
    bad_q[1, 1] = np.nan
    expect_value_error(bad_q, r0, u0)

    bad_r = r0.copy('F')
    bad_r[1, 1] = np.nan
    expect_value_error(q0, bad_r, u0)

    bad_u = u0.copy('F')
    bad_u[0, 0] = np.nan
    expect_value_error(q0, r0, bad_u)
class TestQRinsert_f(BaseQRinsert):
    """Run the BaseQRinsert tests with dtype code 'f'."""

    dtype = np.dtype('f')
class TestQRinsert_F(BaseQRinsert):
    """Run the BaseQRinsert tests with dtype code 'F'."""

    dtype = np.dtype('F')
class TestQRinsert_d(BaseQRinsert):
    """Run the BaseQRinsert tests with dtype code 'd'."""

    dtype = np.dtype('d')
class TestQRinsert_D(BaseQRinsert):
    """Run the BaseQRinsert tests with dtype code 'D'."""

    dtype = np.dtype('D')
class BaseQRupdate(BaseQRdeltas):
    """Shared tests for ``qr_update`` (rank-1 and rank-p updates).

    Concrete subclasses set ``dtype``; the tests also read ``self.rtol``
    and ``self.atol``, which are expected to come from the base class.
    """

    def generate(self, type, mode='full', p=1):
        """Return ``(a, q, r, u, v)`` for an update test.

        ``a, q, r`` come from the parent ``generate(type, mode)``. ``u`` and
        ``v`` are random: 1-D of lengths ``q.shape[0]`` and ``r.shape[1]``
        when ``p == 1``, otherwise 2-D with ``p`` columns. For complex
        dtypes a random imaginary part is added. Both are cast to
        ``self.dtype``.
        """
        a, q, r = super().generate(type, mode)

        # super call set the seed...
        if p == 1:
            u = np.random.random(q.shape[0])
            v = np.random.random(r.shape[1])
        else:
            u = np.random.random((q.shape[0], p))
            v = np.random.random((r.shape[1], p))

        if np.iscomplexobj(self.dtype.type(1)):
            b = np.random.random(u.shape)
            u = u + 1j * b

            c = np.random.random(v.shape)
            v = v + 1j * c

        u = u.astype(self.dtype)
        v = v.astype(self.dtype)
        return a, q, r, u, v

    def test_sqr_rank_1(self):
        a, q, r, u, v = self.generate('sqr')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_sqr_rank_p(self):
        # test ndim = 2, rank 1 updates here too
        for p in [1, 2, 3, 5]:
            a, q, r, u, v = self.generate('sqr', p=p)
            if p == 1:
                u = u.reshape(u.size, 1)
                v = v.reshape(v.size, 1)
            q1, r1 = qr_update(q, r, u, v, False)
            a1 = a + np.dot(u, v.T.conj())
            check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_tall_rank_1(self):
        a, q, r, u, v = self.generate('tall')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_tall_rank_p(self):
        for p in [1, 2, 3, 5]:
            a, q, r, u, v = self.generate('tall', p=p)
            if p == 1:
                u = u.reshape(u.size, 1)
                v = v.reshape(v.size, 1)
            q1, r1 = qr_update(q, r, u, v, False)
            a1 = a + np.dot(u, v.T.conj())
            check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_fat_rank_1(self):
        a, q, r, u, v = self.generate('fat')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_fat_rank_p(self):
        for p in [1, 2, 3, 5]:
            a, q, r, u, v = self.generate('fat', p=p)
            if p == 1:
                u = u.reshape(u.size, 1)
                v = v.reshape(v.size, 1)
            q1, r1 = qr_update(q, r, u, v, False)
            a1 = a + np.dot(u, v.T.conj())
            check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_economic_rank_1(self):
        a, q, r, u, v = self.generate('tall', 'economic')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_economic_rank_p(self):
        for p in [1, 2, 3, 5]:
            a, q, r, u, v = self.generate('tall', 'economic', p)
            if p == 1:
                u = u.reshape(u.size, 1)
                v = v.reshape(v.size, 1)
            q1, r1 = qr_update(q, r, u, v, False)
            a1 = a + np.dot(u, v.T.conj())
            check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_Mx1_rank_1(self):
        a, q, r, u, v = self.generate('Mx1')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_Mx1_rank_p(self):
        # when M or N == 1, only a rank 1 update is allowed. This isn't
        # a fundamental limitation, but the code does not support it.
        a, q, r, u, v = self.generate('Mx1', p=1)
        u = u.reshape(u.size, 1)
        v = v.reshape(v.size, 1)
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.dot(u, v.T.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_Mx1_economic_rank_1(self):
        a, q, r, u, v = self.generate('Mx1', 'economic')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_Mx1_economic_rank_p(self):
        # when M or N == 1, only a rank 1 update is allowed. This isn't
        # a fundamental limitation, but the code does not support it.
        a, q, r, u, v = self.generate('Mx1', 'economic', p=1)
        u = u.reshape(u.size, 1)
        v = v.reshape(v.size, 1)
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.dot(u, v.T.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol, False)

    def test_1xN_rank_1(self):
        a, q, r, u, v = self.generate('1xN')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_1xN_rank_p(self):
        # when M or N == 1, only a rank 1 update is allowed. This isn't
        # a fundamental limitation, but the code does not support it.
        a, q, r, u, v = self.generate('1xN', p=1)
        u = u.reshape(u.size, 1)
        v = v.reshape(v.size, 1)
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.dot(u, v.T.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_1x1_rank_1(self):
        a, q, r, u, v = self.generate('1x1')
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_1x1_rank_p(self):
        # when M or N == 1, only a rank 1 update is allowed. This isn't
        # a fundamental limitation, but the code does not support it.
        a, q, r, u, v = self.generate('1x1', p=1)
        u = u.reshape(u.size, 1)
        v = v.reshape(v.size, 1)
        q1, r1 = qr_update(q, r, u, v, False)
        a1 = a + np.dot(u, v.T.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol)

    def test_1x1_rank_1_scalar(self):
        # scalars in place of any of the four arrays are rejected
        a, q, r, u, v = self.generate('1x1')
        assert_raises(ValueError, qr_update, q[0, 0], r, u, v)
        assert_raises(ValueError, qr_update, q, r[0, 0], u, v)
        assert_raises(ValueError, qr_update, q, r, u[0], v)
        assert_raises(ValueError, qr_update, q, r, u, v[0])

    def base_non_simple_strides(self, adjust_strides, mode, p, overwriteable):
        """Shared driver: run qr_update on inputs with unusual memory layouts.

        Parameters
        ----------
        adjust_strides : callable
            Takes a tuple of arrays and returns the same number of arrays,
            used as the "non-simple" versions of q, r, u and v.
        mode : str
            Passed to ``self.generate``; ``'economic'`` turns off the
            square-Q assertion in ``check_qr``.
        p : int
            Rank of the update.
        overwriteable : bool
            When true, additionally assert that r and v were overwritten
            by the ``overwrite_qruv=True`` call.
        """
        assert_sqr = mode != 'economic'
        for shape_kind in ['sqr', 'tall', 'fat']:
            a, q0, r0, u0, v0 = self.generate(shape_kind, mode, p)
            qs, rs, us, vs = adjust_strides((q0, r0, u0, v0))
            if p == 1:
                aup = a + np.outer(u0, v0.conj())
            else:
                aup = a + np.dot(u0, v0.T.conj())

            # for each variable, q, r, u, v we try with it strided and
            # overwrite=False. Then we try with overwrite=True, and make
            # sure that if p == 1, r and v are still overwritten.
            # a strided q and u must always be copied.

            # adjusted q only
            q = q0.copy('F')
            r = r0.copy('F')
            u = u0.copy('F')
            v = v0.copy('C')
            q1, r1 = qr_update(qs, r, u, v, False)
            check_qr(q1, r1, aup, self.rtol, self.atol, assert_sqr)
            q1o, r1o = qr_update(qs, r, u, v, True)
            check_qr(q1o, r1o, aup, self.rtol, self.atol, assert_sqr)
            if overwriteable:
                assert_allclose(r1o, r, rtol=self.rtol, atol=self.atol)
                assert_allclose(v, v0.conj(), rtol=self.rtol, atol=self.atol)

            # adjusted r only
            q = q0.copy('F')
            r = r0.copy('F')
            u = u0.copy('F')
            v = v0.copy('C')
            q2, r2 = qr_update(q, rs, u, v, False)
            check_qr(q2, r2, aup, self.rtol, self.atol, assert_sqr)
            q2o, r2o = qr_update(q, rs, u, v, True)
            check_qr(q2o, r2o, aup, self.rtol, self.atol, assert_sqr)
            if overwriteable:
                assert_allclose(r2o, rs, rtol=self.rtol, atol=self.atol)
                assert_allclose(v, v0.conj(), rtol=self.rtol, atol=self.atol)

            # adjusted u only
            q = q0.copy('F')
            r = r0.copy('F')
            u = u0.copy('F')
            v = v0.copy('C')
            q3, r3 = qr_update(q, r, us, v, False)
            check_qr(q3, r3, aup, self.rtol, self.atol, assert_sqr)
            q3o, r3o = qr_update(q, r, us, v, True)
            check_qr(q3o, r3o, aup, self.rtol, self.atol, assert_sqr)
            if overwriteable:
                assert_allclose(r3o, r, rtol=self.rtol, atol=self.atol)
                assert_allclose(v, v0.conj(), rtol=self.rtol, atol=self.atol)

            # adjusted v only
            q = q0.copy('F')
            r = r0.copy('F')
            u = u0.copy('F')
            v = v0.copy('C')
            q4, r4 = qr_update(q, r, u, vs, False)
            check_qr(q4, r4, aup, self.rtol, self.atol, assert_sqr)
            q4o, r4o = qr_update(q, r, u, vs, True)
            check_qr(q4o, r4o, aup, self.rtol, self.atol, assert_sqr)
            if overwriteable:
                assert_allclose(r4o, r, rtol=self.rtol, atol=self.atol)
                assert_allclose(vs, v0.conj(), rtol=self.rtol, atol=self.atol)

            # all four adjusted
            q = q0.copy('F')
            r = r0.copy('F')
            u = u0.copy('F')
            v = v0.copy('C')
            # since some of these were consumed above
            qs, rs, us, vs = adjust_strides((q, r, u, v))
            q5, r5 = qr_update(qs, rs, us, vs, False)
            check_qr(q5, r5, aup, self.rtol, self.atol, assert_sqr)
            q5o, r5o = qr_update(qs, rs, us, vs, True)
            check_qr(q5o, r5o, aup, self.rtol, self.atol, assert_sqr)
            if overwriteable:
                assert_allclose(r5o, rs, rtol=self.rtol, atol=self.atol)
                assert_allclose(vs, v0.conj(), rtol=self.rtol, atol=self.atol)

    def test_non_unit_strides_rank_1(self):
        self.base_non_simple_strides(make_strided, 'full', 1, True)

    def test_non_unit_strides_economic_rank_1(self):
        self.base_non_simple_strides(make_strided, 'economic', 1, True)

    def test_non_unit_strides_rank_p(self):
        self.base_non_simple_strides(make_strided, 'full', 3, False)

    def test_non_unit_strides_economic_rank_p(self):
        self.base_non_simple_strides(make_strided, 'economic', 3, False)

    def test_neg_strides_rank_1(self):
        self.base_non_simple_strides(negate_strides, 'full', 1, False)

    def test_neg_strides_economic_rank_1(self):
        self.base_non_simple_strides(negate_strides, 'economic', 1, False)

    def test_neg_strides_rank_p(self):
        self.base_non_simple_strides(negate_strides, 'full', 3, False)

    def test_neg_strides_economic_rank_p(self):
        self.base_non_simple_strides(negate_strides, 'economic', 3, False)

    def test_non_itemsize_strides_rank_1(self):
        self.base_non_simple_strides(nonitemsize_strides, 'full', 1, False)

    def test_non_itemsize_strides_economic_rank_1(self):
        self.base_non_simple_strides(nonitemsize_strides, 'economic', 1, False)

    def test_non_itemsize_strides_rank_p(self):
        self.base_non_simple_strides(nonitemsize_strides, 'full', 3, False)

    def test_non_itemsize_strides_economic_rank_p(self):
        self.base_non_simple_strides(nonitemsize_strides, 'economic', 3, False)

    def test_non_native_byte_order_rank_1(self):
        self.base_non_simple_strides(make_nonnative, 'full', 1, False)

    def test_non_native_byte_order_economic_rank_1(self):
        self.base_non_simple_strides(make_nonnative, 'economic', 1, False)

    def test_non_native_byte_order_rank_p(self):
        self.base_non_simple_strides(make_nonnative, 'full', 3, False)

    def test_non_native_byte_order_economic_rank_p(self):
        self.base_non_simple_strides(make_nonnative, 'economic', 3, False)

    def test_overwrite_qruv_rank_1(self):
        # Any positive strided q, r, u, and v can be overwritten for a rank 1
        # update, only checking C and F contiguous.
        a, q0, r0, u0, v0 = self.generate('sqr')
        a1 = a + np.outer(u0, v0.conj())
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        v = v0.copy('F')

        # don't overwrite
        q1, r1 = qr_update(q, r, u, v, False)
        check_qr(q1, r1, a1, self.rtol, self.atol)
        check_qr(q, r, a, self.rtol, self.atol)

        q2, r2 = qr_update(q, r, u, v, True)
        check_qr(q2, r2, a1, self.rtol, self.atol)
        # verify the overwriting, no good way to check u and v.
        assert_allclose(q2, q, rtol=self.rtol, atol=self.atol)
        assert_allclose(r2, r, rtol=self.rtol, atol=self.atol)

        q = q0.copy('C')
        r = r0.copy('C')
        u = u0.copy('C')
        v = v0.copy('C')
        q3, r3 = qr_update(q, r, u, v, True)
        check_qr(q3, r3, a1, self.rtol, self.atol)
        assert_allclose(q3, q, rtol=self.rtol, atol=self.atol)
        assert_allclose(r3, r, rtol=self.rtol, atol=self.atol)

    def test_overwrite_qruv_rank_1_economic(self):
        # updating economic decompositions can overwrite any contiguous r,
        # and positively strided r and u. V is only ever read.
        # only checking C and F contiguous.
        a, q0, r0, u0, v0 = self.generate('tall', 'economic')
        a1 = a + np.outer(u0, v0.conj())
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        v = v0.copy('F')

        # don't overwrite
        q1, r1 = qr_update(q, r, u, v, False)
        check_qr(q1, r1, a1, self.rtol, self.atol, False)
        check_qr(q, r, a, self.rtol, self.atol, False)

        q2, r2 = qr_update(q, r, u, v, True)
        check_qr(q2, r2, a1, self.rtol, self.atol, False)
        # verify the overwriting, no good way to check u and v.
        assert_allclose(q2, q, rtol=self.rtol, atol=self.atol)
        assert_allclose(r2, r, rtol=self.rtol, atol=self.atol)

        q = q0.copy('C')
        r = r0.copy('C')
        u = u0.copy('C')
        v = v0.copy('C')
        q3, r3 = qr_update(q, r, u, v, True)
        check_qr(q3, r3, a1, self.rtol, self.atol, False)
        assert_allclose(q3, q, rtol=self.rtol, atol=self.atol)
        assert_allclose(r3, r, rtol=self.rtol, atol=self.atol)

    def test_overwrite_qruv_rank_p(self):
        # for rank p updates, q r must be F contiguous, v must be C (v.T --> F)
        # and u can be C or F, but is only overwritten if Q is C and complex
        a, q0, r0, u0, v0 = self.generate('sqr', p=3)
        a1 = a + np.dot(u0, v0.T.conj())
        q = q0.copy('F')
        r = r0.copy('F')
        u = u0.copy('F')
        v = v0.copy('C')

        # don't overwrite
        q1, r1 = qr_update(q, r, u, v, False)
        check_qr(q1, r1, a1, self.rtol, self.atol)
        check_qr(q, r, a, self.rtol, self.atol)

        q2, r2 = qr_update(q, r, u, v, True)
        check_qr(q2, r2, a1, self.rtol, self.atol)
        # verify the overwriting, no good way to check u and v.
        assert_allclose(q2, q, rtol=self.rtol, atol=self.atol)
        assert_allclose(r2, r, rtol=self.rtol, atol=self.atol)

    def test_empty_inputs(self):
        a, q, r, u, v = self.generate('tall')
        assert_raises(ValueError, qr_update, np.array([]), r, u, v)
        assert_raises(ValueError, qr_update, q, np.array([]), u, v)
        assert_raises(ValueError, qr_update, q, r, np.array([]), v)
        assert_raises(ValueError, qr_update, q, r, u, np.array([]))

    def test_mismatched_shapes(self):
        a, q, r, u, v = self.generate('tall')
        assert_raises(ValueError, qr_update, q, r[1:], u, v)
        assert_raises(ValueError, qr_update, q[:-2], r, u, v)
        assert_raises(ValueError, qr_update, q, r, u[1:], v)
        assert_raises(ValueError, qr_update, q, r, u, v[1:])

    def test_unsupported_dtypes(self):
        # 'clongdouble' is the portable name of the extended-precision
        # complex type; the older 'longcomplex' alias is not accepted by
        # NumPy >= 2.0.
        dts = ['int8', 'int16', 'int32', 'int64',
               'uint8', 'uint16', 'uint32', 'uint64',
               'float16', 'longdouble', 'clongdouble',
               'bool']
        a, q0, r0, u0, v0 = self.generate('tall')
        for dtype in dts:
            q = q0.real.astype(dtype)
            # casting r can trigger "invalid value" floating point
            # warnings; only the dtype matters here, so silence them.
            with np.errstate(invalid="ignore"):
                r = r0.real.astype(dtype)
            u = u0.real.astype(dtype)
            v = v0.real.astype(dtype)
            assert_raises(ValueError, qr_update, q, r0, u0, v0)
            assert_raises(ValueError, qr_update, q0, r, u0, v0)
            assert_raises(ValueError, qr_update, q0, r0, u, v0)
            assert_raises(ValueError, qr_update, q0, r0, u0, v)

    def test_integer_input(self):
        q = np.arange(16).reshape(4, 4)
        r = q.copy()  # doesn't matter
        u = q[:, 0].copy()
        v = r[0, :].copy()
        assert_raises(ValueError, qr_update, q, r, u, v)

    def test_check_finite(self):
        # Each argument is poisoned with a NaN in turn while the other
        # three stay finite, so every assertion exercises exactly one check.
        a0, q0, r0, u0, v0 = self.generate('tall', p=3)

        q = q0.copy('F')
        q[1,1] = np.nan
        assert_raises(ValueError, qr_update, q, r0, u0[:,0], v0[:,0])
        assert_raises(ValueError, qr_update, q, r0, u0, v0)

        r = r0.copy('F')
        r[1,1] = np.nan
        assert_raises(ValueError, qr_update, q0, r, u0[:,0], v0[:,0])
        assert_raises(ValueError, qr_update, q0, r, u0, v0)

        u = u0.copy('F')
        u[0,0] = np.nan
        assert_raises(ValueError, qr_update, q0, r0, u[:,0], v0[:,0])
        assert_raises(ValueError, qr_update, q0, r0, u, v0)

        v = v0.copy('F')
        v[0,0] = np.nan
        # use the clean u0 here: the NaN-bearing u would raise on its own
        # and hide a missing finiteness check on v.
        assert_raises(ValueError, qr_update, q0, r0, u0[:,0], v[:,0])
        assert_raises(ValueError, qr_update, q0, r0, u0, v)

    def test_economic_check_finite(self):
        # Same as test_check_finite, for an economic decomposition.
        a0, q0, r0, u0, v0 = self.generate('tall', mode='economic', p=3)

        q = q0.copy('F')
        q[1,1] = np.nan
        assert_raises(ValueError, qr_update, q, r0, u0[:,0], v0[:,0])
        assert_raises(ValueError, qr_update, q, r0, u0, v0)

        r = r0.copy('F')
        r[1,1] = np.nan
        assert_raises(ValueError, qr_update, q0, r, u0[:,0], v0[:,0])
        assert_raises(ValueError, qr_update, q0, r, u0, v0)

        u = u0.copy('F')
        u[0,0] = np.nan
        assert_raises(ValueError, qr_update, q0, r0, u[:,0], v0[:,0])
        assert_raises(ValueError, qr_update, q0, r0, u, v0)

        v = v0.copy('F')
        v[0,0] = np.nan
        # use the clean u0 here: the NaN-bearing u would raise on its own
        # and hide a missing finiteness check on v.
        assert_raises(ValueError, qr_update, q0, r0, u0[:,0], v[:,0])
        assert_raises(ValueError, qr_update, q0, r0, u0, v)

    def test_u_exactly_in_span_q(self):
        # u is the negative of the second column of q, so it lies exactly
        # in span(q); the economic update must still give a valid result.
        q = np.array([[0, 0], [0, 0], [1, 0], [0, 1]], self.dtype)
        r = np.array([[1, 0], [0, 1]], self.dtype)
        u = np.array([0, 0, 0, -1], self.dtype)
        v = np.array([1, 2], self.dtype)
        q1, r1 = qr_update(q, r, u, v)
        a1 = np.dot(q, r) + np.outer(u, v.conj())
        check_qr(q1, r1, a1, self.rtol, self.atol, False)
class TestQRupdate_f(BaseQRupdate):
    """Run the BaseQRupdate tests with dtype code 'f'."""

    dtype = np.dtype('f')
class TestQRupdate_F(BaseQRupdate):
    """Run the BaseQRupdate tests with dtype code 'F'."""

    dtype = np.dtype('F')
class TestQRupdate_d(BaseQRupdate):
    """Run the BaseQRupdate tests with dtype code 'd'."""

    dtype = np.dtype('d')
class TestQRupdate_D(BaseQRupdate):
    """Run the BaseQRupdate tests with dtype code 'D'."""

    dtype = np.dtype('D')
def test_form_qTu():
    """Drive check_form_qTu over every combination of layout and dtype."""
    # We want to ensure that all of the code paths through this function are
    # tested. Most of them should be hit with the rest of test suite, but
    # explicit tests make clear precisely what is being tested.
    #
    # This function expects that Q is either C or F contiguous and square.
    # Economic mode decompositions (Q is (M, N), M != N) do not go through this
    # function. U may have any positive strides.
    #
    # Some of these tests are duplicates, since contiguous 1d arrays are both
    # C and F.
    q_orders = ['F', 'C']
    q_shapes = [(8, 8), ]
    u_orders = ['F', 'C', 'A']  # here A means is not F not C
    u_widths = [1, 3]
    dtypes = ['f', 'd', 'F', 'D']

    combos = itertools.product(q_orders, q_shapes, u_orders, u_widths, dtypes)
    for q_ord, q_shp, u_ord, u_wid, dt in combos:
        # a width-1 u is exercised both as a 1-D and as a 2-D array
        ndims = (1, 2) if u_wid == 1 else (2,)
        for u_ndim in ndims:
            check_form_qTu(q_ord, q_shp, u_ord, u_wid, u_ndim, dt)
def check_form_qTu(q_order, q_shape, u_order, u_shape, u_ndim, dtype):
    """Compare ``_decomp_update._form_qTu(q, u)`` with ``q.T.conj() @ u``.

    Parameters
    ----------
    q_order : {'F', 'C'}
        Memory order requested for q via ``np.require``.
    q_shape : tuple
        Shape of the random q.
    u_order : {'F', 'C', 'A'}
        Memory order for u; ``'A'`` sends u through ``make_strided``
        instead of ``np.require``.
    u_shape : int
        Number of columns of u.
    u_ndim : int
        With ``u_shape == 1`` and ``u_ndim == 1`` u is 1-D of length
        ``q_shape[0]``; otherwise u is 2-D ``(q_shape[0], u_shape)``.
    dtype : str or dtype
        One of the 'f', 'd', 'F', 'D' type codes.

    Raises
    ------
    ValueError
        If ``dtype`` is not one of the four supported type codes.
    """
    np.random.seed(47)
    if u_shape == 1 and u_ndim == 1:
        u_shape = (q_shape[0],)
    else:
        u_shape = (q_shape[0], u_shape)
    dtype = np.dtype(dtype)

    if dtype.char in 'fd':
        q = np.random.random(q_shape)
        u = np.random.random(u_shape)
    elif dtype.char in 'FD':
        q = np.random.random(q_shape) + 1j*np.random.random(q_shape)
        u = np.random.random(u_shape) + 1j*np.random.random(u_shape)
    else:
        # The exception has to be raised; merely constructing it would let
        # execution continue with q and u undefined.
        raise ValueError("form_qTu doesn't support this dtype")

    q = np.require(q, dtype, q_order)
    if u_order != 'A':
        u = np.require(u, dtype, u_order)
    else:
        u, = make_strided((u.astype(dtype),))

    # tolerances scale with the precision of the dtype under test
    rtol = 10.0 ** -(np.finfo(dtype).precision-2)
    atol = 2*np.finfo(dtype).eps

    expected = np.dot(q.T.conj(), u)
    res = _decomp_update._form_qTu(q, u)
    assert_allclose(res, expected, rtol=rtol, atol=atol)
|