test_hermite.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555
  1. """Tests for hermite module.
  2. """
  3. from functools import reduce
  4. import numpy as np
  5. import numpy.polynomial.hermite as herm
  6. from numpy.polynomial.polynomial import polyval
  7. from numpy.testing import (
  8. assert_almost_equal, assert_raises, assert_equal, assert_,
  9. )
  10. H0 = np.array([1])
  11. H1 = np.array([0, 2])
  12. H2 = np.array([-2, 0, 4])
  13. H3 = np.array([0, -12, 0, 8])
  14. H4 = np.array([12, 0, -48, 0, 16])
  15. H5 = np.array([0, 120, 0, -160, 0, 32])
  16. H6 = np.array([-120, 0, 720, 0, -480, 0, 64])
  17. H7 = np.array([0, -1680, 0, 3360, 0, -1344, 0, 128])
  18. H8 = np.array([1680, 0, -13440, 0, 13440, 0, -3584, 0, 256])
  19. H9 = np.array([0, 30240, 0, -80640, 0, 48384, 0, -9216, 0, 512])
  20. Hlist = [H0, H1, H2, H3, H4, H5, H6, H7, H8, H9]
  21. def trim(x):
  22. return herm.hermtrim(x, tol=1e-6)
  23. class TestConstants:
  24. def test_hermdomain(self):
  25. assert_equal(herm.hermdomain, [-1, 1])
  26. def test_hermzero(self):
  27. assert_equal(herm.hermzero, [0])
  28. def test_hermone(self):
  29. assert_equal(herm.hermone, [1])
  30. def test_hermx(self):
  31. assert_equal(herm.hermx, [0, .5])
  32. class TestArithmetic:
  33. x = np.linspace(-3, 3, 100)
  34. def test_hermadd(self):
  35. for i in range(5):
  36. for j in range(5):
  37. msg = f"At i={i}, j={j}"
  38. tgt = np.zeros(max(i, j) + 1)
  39. tgt[i] += 1
  40. tgt[j] += 1
  41. res = herm.hermadd([0]*i + [1], [0]*j + [1])
  42. assert_equal(trim(res), trim(tgt), err_msg=msg)
  43. def test_hermsub(self):
  44. for i in range(5):
  45. for j in range(5):
  46. msg = f"At i={i}, j={j}"
  47. tgt = np.zeros(max(i, j) + 1)
  48. tgt[i] += 1
  49. tgt[j] -= 1
  50. res = herm.hermsub([0]*i + [1], [0]*j + [1])
  51. assert_equal(trim(res), trim(tgt), err_msg=msg)
  52. def test_hermmulx(self):
  53. assert_equal(herm.hermmulx([0]), [0])
  54. assert_equal(herm.hermmulx([1]), [0, .5])
  55. for i in range(1, 5):
  56. ser = [0]*i + [1]
  57. tgt = [0]*(i - 1) + [i, 0, .5]
  58. assert_equal(herm.hermmulx(ser), tgt)
  59. def test_hermmul(self):
  60. # check values of result
  61. for i in range(5):
  62. pol1 = [0]*i + [1]
  63. val1 = herm.hermval(self.x, pol1)
  64. for j in range(5):
  65. msg = f"At i={i}, j={j}"
  66. pol2 = [0]*j + [1]
  67. val2 = herm.hermval(self.x, pol2)
  68. pol3 = herm.hermmul(pol1, pol2)
  69. val3 = herm.hermval(self.x, pol3)
  70. assert_(len(pol3) == i + j + 1, msg)
  71. assert_almost_equal(val3, val1*val2, err_msg=msg)
  72. def test_hermdiv(self):
  73. for i in range(5):
  74. for j in range(5):
  75. msg = f"At i={i}, j={j}"
  76. ci = [0]*i + [1]
  77. cj = [0]*j + [1]
  78. tgt = herm.hermadd(ci, cj)
  79. quo, rem = herm.hermdiv(tgt, ci)
  80. res = herm.hermadd(herm.hermmul(quo, ci), rem)
  81. assert_equal(trim(res), trim(tgt), err_msg=msg)
  82. def test_hermpow(self):
  83. for i in range(5):
  84. for j in range(5):
  85. msg = f"At i={i}, j={j}"
  86. c = np.arange(i + 1)
  87. tgt = reduce(herm.hermmul, [c]*j, np.array([1]))
  88. res = herm.hermpow(c, j)
  89. assert_equal(trim(res), trim(tgt), err_msg=msg)
  90. class TestEvaluation:
  91. # coefficients of 1 + 2*x + 3*x**2
  92. c1d = np.array([2.5, 1., .75])
  93. c2d = np.einsum('i,j->ij', c1d, c1d)
  94. c3d = np.einsum('i,j,k->ijk', c1d, c1d, c1d)
  95. # some random values in [-1, 1)
  96. x = np.random.random((3, 5))*2 - 1
  97. y = polyval(x, [1., 2., 3.])
  98. def test_hermval(self):
  99. #check empty input
  100. assert_equal(herm.hermval([], [1]).size, 0)
  101. #check normal input)
  102. x = np.linspace(-1, 1)
  103. y = [polyval(x, c) for c in Hlist]
  104. for i in range(10):
  105. msg = f"At i={i}"
  106. tgt = y[i]
  107. res = herm.hermval(x, [0]*i + [1])
  108. assert_almost_equal(res, tgt, err_msg=msg)
  109. #check that shape is preserved
  110. for i in range(3):
  111. dims = [2]*i
  112. x = np.zeros(dims)
  113. assert_equal(herm.hermval(x, [1]).shape, dims)
  114. assert_equal(herm.hermval(x, [1, 0]).shape, dims)
  115. assert_equal(herm.hermval(x, [1, 0, 0]).shape, dims)
  116. def test_hermval2d(self):
  117. x1, x2, x3 = self.x
  118. y1, y2, y3 = self.y
  119. #test exceptions
  120. assert_raises(ValueError, herm.hermval2d, x1, x2[:2], self.c2d)
  121. #test values
  122. tgt = y1*y2
  123. res = herm.hermval2d(x1, x2, self.c2d)
  124. assert_almost_equal(res, tgt)
  125. #test shape
  126. z = np.ones((2, 3))
  127. res = herm.hermval2d(z, z, self.c2d)
  128. assert_(res.shape == (2, 3))
  129. def test_hermval3d(self):
  130. x1, x2, x3 = self.x
  131. y1, y2, y3 = self.y
  132. #test exceptions
  133. assert_raises(ValueError, herm.hermval3d, x1, x2, x3[:2], self.c3d)
  134. #test values
  135. tgt = y1*y2*y3
  136. res = herm.hermval3d(x1, x2, x3, self.c3d)
  137. assert_almost_equal(res, tgt)
  138. #test shape
  139. z = np.ones((2, 3))
  140. res = herm.hermval3d(z, z, z, self.c3d)
  141. assert_(res.shape == (2, 3))
  142. def test_hermgrid2d(self):
  143. x1, x2, x3 = self.x
  144. y1, y2, y3 = self.y
  145. #test values
  146. tgt = np.einsum('i,j->ij', y1, y2)
  147. res = herm.hermgrid2d(x1, x2, self.c2d)
  148. assert_almost_equal(res, tgt)
  149. #test shape
  150. z = np.ones((2, 3))
  151. res = herm.hermgrid2d(z, z, self.c2d)
  152. assert_(res.shape == (2, 3)*2)
  153. def test_hermgrid3d(self):
  154. x1, x2, x3 = self.x
  155. y1, y2, y3 = self.y
  156. #test values
  157. tgt = np.einsum('i,j,k->ijk', y1, y2, y3)
  158. res = herm.hermgrid3d(x1, x2, x3, self.c3d)
  159. assert_almost_equal(res, tgt)
  160. #test shape
  161. z = np.ones((2, 3))
  162. res = herm.hermgrid3d(z, z, z, self.c3d)
  163. assert_(res.shape == (2, 3)*3)
  164. class TestIntegral:
  165. def test_hermint(self):
  166. # check exceptions
  167. assert_raises(TypeError, herm.hermint, [0], .5)
  168. assert_raises(ValueError, herm.hermint, [0], -1)
  169. assert_raises(ValueError, herm.hermint, [0], 1, [0, 0])
  170. assert_raises(ValueError, herm.hermint, [0], lbnd=[0])
  171. assert_raises(ValueError, herm.hermint, [0], scl=[0])
  172. assert_raises(TypeError, herm.hermint, [0], axis=.5)
  173. # test integration of zero polynomial
  174. for i in range(2, 5):
  175. k = [0]*(i - 2) + [1]
  176. res = herm.hermint([0], m=i, k=k)
  177. assert_almost_equal(res, [0, .5])
  178. # check single integration with integration constant
  179. for i in range(5):
  180. scl = i + 1
  181. pol = [0]*i + [1]
  182. tgt = [i] + [0]*i + [1/scl]
  183. hermpol = herm.poly2herm(pol)
  184. hermint = herm.hermint(hermpol, m=1, k=[i])
  185. res = herm.herm2poly(hermint)
  186. assert_almost_equal(trim(res), trim(tgt))
  187. # check single integration with integration constant and lbnd
  188. for i in range(5):
  189. scl = i + 1
  190. pol = [0]*i + [1]
  191. hermpol = herm.poly2herm(pol)
  192. hermint = herm.hermint(hermpol, m=1, k=[i], lbnd=-1)
  193. assert_almost_equal(herm.hermval(-1, hermint), i)
  194. # check single integration with integration constant and scaling
  195. for i in range(5):
  196. scl = i + 1
  197. pol = [0]*i + [1]
  198. tgt = [i] + [0]*i + [2/scl]
  199. hermpol = herm.poly2herm(pol)
  200. hermint = herm.hermint(hermpol, m=1, k=[i], scl=2)
  201. res = herm.herm2poly(hermint)
  202. assert_almost_equal(trim(res), trim(tgt))
  203. # check multiple integrations with default k
  204. for i in range(5):
  205. for j in range(2, 5):
  206. pol = [0]*i + [1]
  207. tgt = pol[:]
  208. for k in range(j):
  209. tgt = herm.hermint(tgt, m=1)
  210. res = herm.hermint(pol, m=j)
  211. assert_almost_equal(trim(res), trim(tgt))
  212. # check multiple integrations with defined k
  213. for i in range(5):
  214. for j in range(2, 5):
  215. pol = [0]*i + [1]
  216. tgt = pol[:]
  217. for k in range(j):
  218. tgt = herm.hermint(tgt, m=1, k=[k])
  219. res = herm.hermint(pol, m=j, k=list(range(j)))
  220. assert_almost_equal(trim(res), trim(tgt))
  221. # check multiple integrations with lbnd
  222. for i in range(5):
  223. for j in range(2, 5):
  224. pol = [0]*i + [1]
  225. tgt = pol[:]
  226. for k in range(j):
  227. tgt = herm.hermint(tgt, m=1, k=[k], lbnd=-1)
  228. res = herm.hermint(pol, m=j, k=list(range(j)), lbnd=-1)
  229. assert_almost_equal(trim(res), trim(tgt))
  230. # check multiple integrations with scaling
  231. for i in range(5):
  232. for j in range(2, 5):
  233. pol = [0]*i + [1]
  234. tgt = pol[:]
  235. for k in range(j):
  236. tgt = herm.hermint(tgt, m=1, k=[k], scl=2)
  237. res = herm.hermint(pol, m=j, k=list(range(j)), scl=2)
  238. assert_almost_equal(trim(res), trim(tgt))
  239. def test_hermint_axis(self):
  240. # check that axis keyword works
  241. c2d = np.random.random((3, 4))
  242. tgt = np.vstack([herm.hermint(c) for c in c2d.T]).T
  243. res = herm.hermint(c2d, axis=0)
  244. assert_almost_equal(res, tgt)
  245. tgt = np.vstack([herm.hermint(c) for c in c2d])
  246. res = herm.hermint(c2d, axis=1)
  247. assert_almost_equal(res, tgt)
  248. tgt = np.vstack([herm.hermint(c, k=3) for c in c2d])
  249. res = herm.hermint(c2d, k=3, axis=1)
  250. assert_almost_equal(res, tgt)
  251. class TestDerivative:
  252. def test_hermder(self):
  253. # check exceptions
  254. assert_raises(TypeError, herm.hermder, [0], .5)
  255. assert_raises(ValueError, herm.hermder, [0], -1)
  256. # check that zeroth derivative does nothing
  257. for i in range(5):
  258. tgt = [0]*i + [1]
  259. res = herm.hermder(tgt, m=0)
  260. assert_equal(trim(res), trim(tgt))
  261. # check that derivation is the inverse of integration
  262. for i in range(5):
  263. for j in range(2, 5):
  264. tgt = [0]*i + [1]
  265. res = herm.hermder(herm.hermint(tgt, m=j), m=j)
  266. assert_almost_equal(trim(res), trim(tgt))
  267. # check derivation with scaling
  268. for i in range(5):
  269. for j in range(2, 5):
  270. tgt = [0]*i + [1]
  271. res = herm.hermder(herm.hermint(tgt, m=j, scl=2), m=j, scl=.5)
  272. assert_almost_equal(trim(res), trim(tgt))
  273. def test_hermder_axis(self):
  274. # check that axis keyword works
  275. c2d = np.random.random((3, 4))
  276. tgt = np.vstack([herm.hermder(c) for c in c2d.T]).T
  277. res = herm.hermder(c2d, axis=0)
  278. assert_almost_equal(res, tgt)
  279. tgt = np.vstack([herm.hermder(c) for c in c2d])
  280. res = herm.hermder(c2d, axis=1)
  281. assert_almost_equal(res, tgt)
  282. class TestVander:
  283. # some random values in [-1, 1)
  284. x = np.random.random((3, 5))*2 - 1
  285. def test_hermvander(self):
  286. # check for 1d x
  287. x = np.arange(3)
  288. v = herm.hermvander(x, 3)
  289. assert_(v.shape == (3, 4))
  290. for i in range(4):
  291. coef = [0]*i + [1]
  292. assert_almost_equal(v[..., i], herm.hermval(x, coef))
  293. # check for 2d x
  294. x = np.array([[1, 2], [3, 4], [5, 6]])
  295. v = herm.hermvander(x, 3)
  296. assert_(v.shape == (3, 2, 4))
  297. for i in range(4):
  298. coef = [0]*i + [1]
  299. assert_almost_equal(v[..., i], herm.hermval(x, coef))
  300. def test_hermvander2d(self):
  301. # also tests hermval2d for non-square coefficient array
  302. x1, x2, x3 = self.x
  303. c = np.random.random((2, 3))
  304. van = herm.hermvander2d(x1, x2, [1, 2])
  305. tgt = herm.hermval2d(x1, x2, c)
  306. res = np.dot(van, c.flat)
  307. assert_almost_equal(res, tgt)
  308. # check shape
  309. van = herm.hermvander2d([x1], [x2], [1, 2])
  310. assert_(van.shape == (1, 5, 6))
  311. def test_hermvander3d(self):
  312. # also tests hermval3d for non-square coefficient array
  313. x1, x2, x3 = self.x
  314. c = np.random.random((2, 3, 4))
  315. van = herm.hermvander3d(x1, x2, x3, [1, 2, 3])
  316. tgt = herm.hermval3d(x1, x2, x3, c)
  317. res = np.dot(van, c.flat)
  318. assert_almost_equal(res, tgt)
  319. # check shape
  320. van = herm.hermvander3d([x1], [x2], [x3], [1, 2, 3])
  321. assert_(van.shape == (1, 5, 24))
  322. class TestFitting:
  323. def test_hermfit(self):
  324. def f(x):
  325. return x*(x - 1)*(x - 2)
  326. def f2(x):
  327. return x**4 + x**2 + 1
  328. # Test exceptions
  329. assert_raises(ValueError, herm.hermfit, [1], [1], -1)
  330. assert_raises(TypeError, herm.hermfit, [[1]], [1], 0)
  331. assert_raises(TypeError, herm.hermfit, [], [1], 0)
  332. assert_raises(TypeError, herm.hermfit, [1], [[[1]]], 0)
  333. assert_raises(TypeError, herm.hermfit, [1, 2], [1], 0)
  334. assert_raises(TypeError, herm.hermfit, [1], [1, 2], 0)
  335. assert_raises(TypeError, herm.hermfit, [1], [1], 0, w=[[1]])
  336. assert_raises(TypeError, herm.hermfit, [1], [1], 0, w=[1, 1])
  337. assert_raises(ValueError, herm.hermfit, [1], [1], [-1,])
  338. assert_raises(ValueError, herm.hermfit, [1], [1], [2, -1, 6])
  339. assert_raises(TypeError, herm.hermfit, [1], [1], [])
  340. # Test fit
  341. x = np.linspace(0, 2)
  342. y = f(x)
  343. #
  344. coef3 = herm.hermfit(x, y, 3)
  345. assert_equal(len(coef3), 4)
  346. assert_almost_equal(herm.hermval(x, coef3), y)
  347. coef3 = herm.hermfit(x, y, [0, 1, 2, 3])
  348. assert_equal(len(coef3), 4)
  349. assert_almost_equal(herm.hermval(x, coef3), y)
  350. #
  351. coef4 = herm.hermfit(x, y, 4)
  352. assert_equal(len(coef4), 5)
  353. assert_almost_equal(herm.hermval(x, coef4), y)
  354. coef4 = herm.hermfit(x, y, [0, 1, 2, 3, 4])
  355. assert_equal(len(coef4), 5)
  356. assert_almost_equal(herm.hermval(x, coef4), y)
  357. # check things still work if deg is not in strict increasing
  358. coef4 = herm.hermfit(x, y, [2, 3, 4, 1, 0])
  359. assert_equal(len(coef4), 5)
  360. assert_almost_equal(herm.hermval(x, coef4), y)
  361. #
  362. coef2d = herm.hermfit(x, np.array([y, y]).T, 3)
  363. assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
  364. coef2d = herm.hermfit(x, np.array([y, y]).T, [0, 1, 2, 3])
  365. assert_almost_equal(coef2d, np.array([coef3, coef3]).T)
  366. # test weighting
  367. w = np.zeros_like(x)
  368. yw = y.copy()
  369. w[1::2] = 1
  370. y[0::2] = 0
  371. wcoef3 = herm.hermfit(x, yw, 3, w=w)
  372. assert_almost_equal(wcoef3, coef3)
  373. wcoef3 = herm.hermfit(x, yw, [0, 1, 2, 3], w=w)
  374. assert_almost_equal(wcoef3, coef3)
  375. #
  376. wcoef2d = herm.hermfit(x, np.array([yw, yw]).T, 3, w=w)
  377. assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
  378. wcoef2d = herm.hermfit(x, np.array([yw, yw]).T, [0, 1, 2, 3], w=w)
  379. assert_almost_equal(wcoef2d, np.array([coef3, coef3]).T)
  380. # test scaling with complex values x points whose square
  381. # is zero when summed.
  382. x = [1, 1j, -1, -1j]
  383. assert_almost_equal(herm.hermfit(x, x, 1), [0, .5])
  384. assert_almost_equal(herm.hermfit(x, x, [0, 1]), [0, .5])
  385. # test fitting only even Legendre polynomials
  386. x = np.linspace(-1, 1)
  387. y = f2(x)
  388. coef1 = herm.hermfit(x, y, 4)
  389. assert_almost_equal(herm.hermval(x, coef1), y)
  390. coef2 = herm.hermfit(x, y, [0, 2, 4])
  391. assert_almost_equal(herm.hermval(x, coef2), y)
  392. assert_almost_equal(coef1, coef2)
  393. class TestCompanion:
  394. def test_raises(self):
  395. assert_raises(ValueError, herm.hermcompanion, [])
  396. assert_raises(ValueError, herm.hermcompanion, [1])
  397. def test_dimensions(self):
  398. for i in range(1, 5):
  399. coef = [0]*i + [1]
  400. assert_(herm.hermcompanion(coef).shape == (i, i))
  401. def test_linear_root(self):
  402. assert_(herm.hermcompanion([1, 2])[0, 0] == -.25)
  403. class TestGauss:
  404. def test_100(self):
  405. x, w = herm.hermgauss(100)
  406. # test orthogonality. Note that the results need to be normalized,
  407. # otherwise the huge values that can arise from fast growing
  408. # functions like Laguerre can be very confusing.
  409. v = herm.hermvander(x, 99)
  410. vv = np.dot(v.T * w, v)
  411. vd = 1/np.sqrt(vv.diagonal())
  412. vv = vd[:, None] * vv * vd
  413. assert_almost_equal(vv, np.eye(100))
  414. # check that the integral of 1 is correct
  415. tgt = np.sqrt(np.pi)
  416. assert_almost_equal(w.sum(), tgt)
  417. class TestMisc:
  418. def test_hermfromroots(self):
  419. res = herm.hermfromroots([])
  420. assert_almost_equal(trim(res), [1])
  421. for i in range(1, 5):
  422. roots = np.cos(np.linspace(-np.pi, 0, 2*i + 1)[1::2])
  423. pol = herm.hermfromroots(roots)
  424. res = herm.hermval(roots, pol)
  425. tgt = 0
  426. assert_(len(pol) == i + 1)
  427. assert_almost_equal(herm.herm2poly(pol)[-1], 1)
  428. assert_almost_equal(res, tgt)
  429. def test_hermroots(self):
  430. assert_almost_equal(herm.hermroots([1]), [])
  431. assert_almost_equal(herm.hermroots([1, 1]), [-.5])
  432. for i in range(2, 5):
  433. tgt = np.linspace(-1, 1, i)
  434. res = herm.hermroots(herm.hermfromroots(tgt))
  435. assert_almost_equal(trim(res), trim(tgt))
  436. def test_hermtrim(self):
  437. coef = [2, -1, 1, 0]
  438. # Test exceptions
  439. assert_raises(ValueError, herm.hermtrim, coef, -1)
  440. # Test results
  441. assert_equal(herm.hermtrim(coef), coef[:-1])
  442. assert_equal(herm.hermtrim(coef, 1), coef[:-3])
  443. assert_equal(herm.hermtrim(coef, 2), [0])
  444. def test_hermline(self):
  445. assert_equal(herm.hermline(3, 4), [3, 2])
  446. def test_herm2poly(self):
  447. for i in range(10):
  448. assert_almost_equal(herm.herm2poly([0]*i + [1]), Hlist[i])
  449. def test_poly2herm(self):
  450. for i in range(10):
  451. assert_almost_equal(herm.poly2herm(Hlist[i]), [0]*i + [1])
  452. def test_weight(self):
  453. x = np.linspace(-5, 5, 11)
  454. tgt = np.exp(-x**2)
  455. res = herm.hermweight(x)
  456. assert_almost_equal(res, tgt)