# test_fit.py
import os

import numpy as np
import numpy.testing as npt
from numpy.testing import assert_allclose, assert_equal
import pytest
from scipy import stats
from scipy.optimize import differential_evolution

from .test_continuous_basic import distcont
from scipy.stats._distn_infrastructure import FitError
from scipy.stats._distr_params import distdiscrete
from scipy.stats import goodness_of_fit


# this is not a proper statistical test for convergence, but only
# verifies that the estimate and true values don't differ by too much
fit_sizes = [1000, 5000, 10000]  # sample sizes to try

thresh_percent = 0.25  # fraction of the true parameter value used as the
                       # failure cut-off
thresh_min = 0.75  # minimum |estimate - true| required to fail the test
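

# Illustrative sketch (not part of the test suite): how `test_cont_fit` below
# assembles its element-wise failure threshold.  For a hypothetical one-shape
# distribution with true shape 2.0, loc 0.0, and scale 1.0, each parameter's
# cut-off is the larger of `thresh_percent` times the true value and
# `thresh_min`, so parameters with small true values still get a usable
# tolerance.
def _threshold_example():
    truearg = np.hstack([[2.0], [0.0, 1.0]])  # shape, loc, scale
    cutoff = np.max(np.vstack([truearg * thresh_percent,
                               np.full(3, thresh_min)]), 0)
    return cutoff  # array([0.75, 0.75, 0.75])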


mle_failing_fits = [
    'burr',
    'chi2',
    'gausshyper',
    'genexpon',
    'gengamma',
    'kappa4',
    'ksone',
    'kstwo',
    'mielke',
    'ncf',
    'ncx2',
    'pearson3',
    'powerlognorm',
    'truncexpon',
    'truncpareto',
    'tukeylambda',
    'vonmises',
    'levy_stable',
    'trapezoid',
    'truncweibull_min',
    'studentized_range',
]

mm_failing_fits = ['alpha', 'betaprime', 'burr', 'burr12', 'cauchy', 'chi',
                   'chi2', 'crystalball', 'dgamma', 'dweibull', 'f',
                   'fatiguelife', 'fisk', 'foldcauchy', 'genextreme',
                   'gengamma', 'genhyperbolic', 'gennorm', 'genpareto',
                   'halfcauchy', 'invgamma', 'invweibull', 'johnsonsu',
                   'kappa3', 'ksone', 'kstwo', 'levy', 'levy_l',
                   'levy_stable', 'loglaplace', 'lomax', 'mielke', 'nakagami',
                   'ncf', 'nct', 'ncx2', 'pareto', 'powerlognorm', 'powernorm',
                   'skewcauchy', 't', 'trapezoid', 'triang', 'truncpareto',
                   'truncweibull_min', 'tukeylambda', 'studentized_range']

# not sure if these fail, but they caused my patience to fail
mm_slow_fits = ['argus', 'exponpow', 'exponweib', 'gausshyper', 'genexpon',
                'genhalflogistic', 'halfgennorm', 'gompertz', 'johnsonsb',
                'kappa4', 'kstwobign', 'recipinvgauss',
                'truncexpon', 'vonmises', 'vonmises_line']

failing_fits = {"MM": mm_failing_fits + mm_slow_fits, "MLE": mle_failing_fits}

# Don't run the fit test on these:
skip_fit = [
    'erlang',  # Subclass of gamma, generates a warning.
    'genhyperbolic',  # too slow
]


def cases_test_cont_fit():
    # this tests the closeness of the estimated parameters to the true
    # parameters with the `fit` method of continuous distributions
    # Note: this is slow; some distributions don't converge with sample
    # size <= 10000
    for distname, arg in distcont:
        if distname not in skip_fit:
            yield distname, arg


@pytest.mark.slow
@pytest.mark.parametrize('distname,arg', cases_test_cont_fit())
@pytest.mark.parametrize('method', ["MLE", "MM"])
def test_cont_fit(distname, arg, method):
    if distname in failing_fits[method]:
        # Skip failing fits unless overridden
        try:
            xfail = not int(os.environ['SCIPY_XFAIL'])
        except Exception:
            xfail = True
        if xfail:
            msg = "Fitting %s doesn't work reliably yet" % distname
            msg += (" [Set environment variable SCIPY_XFAIL=1 to run this"
                    " test nevertheless.]")
            pytest.xfail(msg)

    distfn = getattr(stats, distname)

    truearg = np.hstack([arg, [0.0, 1.0]])
    diffthreshold = np.max(np.vstack([truearg*thresh_percent,
                                      np.full(distfn.numargs+2, thresh_min)]),
                           0)

    for fit_size in fit_sizes:
        # Note that if a fit succeeds, the other fit_sizes are skipped
        np.random.seed(1234)

        with np.errstate(all='ignore'):
            rvs = distfn.rvs(size=fit_size, *arg)
            est = distfn.fit(rvs, method=method)  # start with default values
            diff = est - truearg

            # threshold for location
            diffthreshold[-2] = np.max([np.abs(rvs.mean())*thresh_percent,
                                        thresh_min])

            if np.any(np.isnan(est)):
                raise AssertionError('nan returned in fit')
            else:
                if np.all(np.abs(diff) <= diffthreshold):
                    break
    else:
        txt = 'parameter: %s\n' % str(truearg)
        txt += 'estimated: %s\n' % str(est)
        txt += 'diff : %s\n' % str(diff)
        raise AssertionError('fit not very good in %s\n' % distfn.name + txt)
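

# Illustrative sketch (not part of the test suite): the two estimation
# methods exercised above, spelled out for one distribution.  This sketch
# assumes `gamma.fit` without `floc` defers to the generic fitting machinery;
# `fit` returns (shapes..., loc, scale).
def _fit_methods_example():
    rng = np.random.default_rng(12345)
    sample = stats.gamma.rvs(2.0, size=5000, random_state=rng)
    est_mle = stats.gamma.fit(sample, method="MLE")  # maximum likelihood
    est_mm = stats.gamma.fit(sample, method="MM")    # moment matching
    # Both estimates should land near (a=2, loc=0, scale=1) for a sample
    # this large.
    return est_mle, est_mm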


def _check_loc_scale_mle_fit(name, data, desired, atol=None):
    d = getattr(stats, name)
    actual = d.fit(data)[-2:]
    assert_allclose(actual, desired, atol=atol,
                    err_msg='poor mle fit of (loc, scale) in %s' % name)


def test_non_default_loc_scale_mle_fit():
    data = np.array([1.01, 1.78, 1.78, 1.78, 1.88, 1.88, 1.88, 2.00])
    _check_loc_scale_mle_fit('uniform', data, [1.01, 0.99], 1e-3)
    _check_loc_scale_mle_fit('expon', data, [1.01, 0.73875], 1e-3)


def test_expon_fit():
    """gh-6167"""
    data = [0, 0, 0, 0, 2, 2, 2, 2]
    phat = stats.expon.fit(data, floc=0)
    assert_allclose(phat, [0, 1.0], atol=1e-3)


def test_fit_error():
    data = np.concatenate([np.zeros(29), np.ones(21)])
    message = "Optimization converged to parameters that are..."
    with pytest.raises(FitError, match=message), \
            pytest.warns(RuntimeWarning):
        stats.beta.fit(data)


@pytest.mark.parametrize("dist, params",
                         [(stats.norm, (0.5, 2.5)),  # type: ignore[attr-defined] # noqa
                          (stats.binom, (10, 0.3, 2))])  # type: ignore[attr-defined] # noqa
def test_nnlf_and_related_methods(dist, params):
    rng = np.random.default_rng(983459824)

    if hasattr(dist, 'pdf'):
        logpxf = dist.logpdf
    else:
        logpxf = dist.logpmf

    x = dist.rvs(*params, size=100, random_state=rng)
    ref = -logpxf(x, *params).sum()

    res1 = dist.nnlf(params, x)
    res2 = dist._penalized_nnlf(params, x)

    assert_allclose(res1, ref)
    assert_allclose(res2, ref)


def cases_test_fit_mle():
    # These fail default test or hang
    skip_basic_fit = {'argus', 'foldnorm', 'truncpareto', 'truncweibull_min',
                      'ksone', 'levy_stable', 'studentized_range', 'kstwo'}

    slow_basic_fit = {'burr12', 'johnsonsb', 'bradford', 'fisk', 'mielke',
                      'exponpow', 'rdist', 'norminvgauss', 'betaprime',
                      'powerlaw', 'pareto', 'johnsonsu', 'loglaplace',
                      'wrapcauchy', 'weibull_max', 'arcsine', 'binom', 'rice',
                      'uniform', 'f', 'invweibull', 'genpareto',
                      'nbinom', 'kappa3', 'lognorm', 'halfgennorm', 'pearson3',
                      'alpha', 't', 'crystalball', 'fatiguelife', 'nakagami',
                      'kstwobign', 'gompertz', 'dweibull', 'lomax', 'invgauss',
                      'recipinvgauss', 'chi', 'foldcauchy', 'powernorm',
                      'gennorm', 'randint', 'genextreme'}

    xslow_basic_fit = {'nchypergeom_fisher', 'nchypergeom_wallenius',
                       'gausshyper', 'genexpon', 'gengamma', 'genhyperbolic',
                       'geninvgauss', 'tukeylambda', 'skellam', 'ncx2',
                       'hypergeom', 'nhypergeom', 'zipfian', 'ncf',
                       'truncnorm', 'powerlognorm', 'beta',
                       'loguniform', 'reciprocal', 'trapezoid', 'nct',
                       'kappa4', 'betabinom', 'exponweib', 'genhalflogistic',
                       'burr', 'triang'}

    for dist in dict(distdiscrete + distcont):
        if dist in skip_basic_fit or not isinstance(dist, str):
            reason = "tested separately"
            yield pytest.param(dist, marks=pytest.mark.skip(reason=reason))
        elif dist in slow_basic_fit:
            reason = "too slow (>= 0.25s)"
            yield pytest.param(dist, marks=pytest.mark.slow(reason=reason))
        elif dist in xslow_basic_fit:
            reason = "too slow (>= 1.0s)"
            yield pytest.param(dist, marks=pytest.mark.xslow(reason=reason))
        else:
            yield dist


def cases_test_fit_mse():
    # the first four are so slow that I'm not sure whether they would pass
    skip_basic_fit = {'levy_stable', 'studentized_range', 'ksone', 'skewnorm',
                      'norminvgauss',  # super slow (~1 hr) but passes
                      'kstwo',  # very slow (~25 min) but passes
                      'geninvgauss',  # quite slow (~4 minutes) but passes
                      'gausshyper', 'genhyperbolic',  # integration warnings
                      'argus',  # close, but doesn't meet tolerance
                      'vonmises'}  # can have negative CDF; doesn't play nice

    slow_basic_fit = {'wald', 'genextreme', 'anglit', 'semicircular',
                      'kstwobign', 'arcsine', 'genlogistic', 'truncexpon',
                      'fisk', 'uniform', 'exponnorm', 'maxwell', 'lomax',
                      'laplace_asymmetric', 'lognorm', 'foldcauchy',
                      'genpareto', 'powernorm', 'loglaplace', 'foldnorm',
                      'recipinvgauss', 'exponpow', 'bradford', 'weibull_max',
                      'gompertz', 'dweibull', 'truncpareto', 'weibull_min',
                      'johnsonsu', 'loggamma', 'kappa3', 'fatiguelife',
                      'pareto', 'invweibull', 'alpha', 'erlang', 'dgamma',
                      'chi2', 'crystalball', 'nakagami', 'truncweibull_min',
                      't', 'vonmises_line', 'triang', 'wrapcauchy', 'gamma',
                      'mielke', 'chi', 'johnsonsb', 'exponweib',
                      'genhalflogistic', 'randint', 'nhypergeom', 'hypergeom',
                      'betabinom'}

    xslow_basic_fit = {'burr', 'halfgennorm', 'invgamma',
                       'invgauss', 'powerlaw', 'burr12', 'trapezoid', 'kappa4',
                       'f', 'powerlognorm', 'ncx2', 'rdist', 'reciprocal',
                       'loguniform', 'betaprime', 'rice', 'gennorm',
                       'gengamma', 'truncnorm', 'ncf', 'nct', 'pearson3',
                       'beta', 'genexpon', 'tukeylambda', 'zipfian',
                       'nchypergeom_wallenius', 'nchypergeom_fisher'}

    warns_basic_fit = {'skellam'}  # can remove mark after gh-14901 is resolved

    for dist in dict(distdiscrete + distcont):
        if dist in skip_basic_fit or not isinstance(dist, str):
            reason = "Fails. Oh well."
            yield pytest.param(dist, marks=pytest.mark.skip(reason=reason))
        elif dist in slow_basic_fit:
            reason = "too slow (>= 0.25s)"
            yield pytest.param(dist, marks=pytest.mark.slow(reason=reason))
        elif dist in xslow_basic_fit:
            reason = "too slow (>= 1.0s)"
            yield pytest.param(dist, marks=pytest.mark.xslow(reason=reason))
        elif dist in warns_basic_fit:
            mark = pytest.mark.filterwarnings('ignore::RuntimeWarning')
            yield pytest.param(dist, marks=mark)
        else:
            yield dist


def cases_test_fitstart():
    for distname, shapes in dict(distcont).items():
        if (not isinstance(distname, str) or
                distname in {'studentized_range', 'recipinvgauss'}):  # slow
            continue
        yield distname, shapes


@pytest.mark.parametrize('distname, shapes', cases_test_fitstart())
def test_fitstart(distname, shapes):
    dist = getattr(stats, distname)
    rng = np.random.default_rng(216342614)
    data = rng.random(10)

    with np.errstate(invalid='ignore', divide='ignore'):  # irrelevant to test
        guess = dist._fitstart(data)

    assert dist._argcheck(*guess[:-2])


def assert_nlff_less_or_close(dist, data, params1, params0, rtol=1e-7, atol=0,
                              nlff_name='nnlf'):
    nlff = getattr(dist, nlff_name)
    nlff1 = nlff(params1, data)
    nlff0 = nlff(params0, data)
    if not (nlff1 < nlff0):
        np.testing.assert_allclose(nlff1, nlff0, rtol=rtol, atol=atol)
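

# Illustrative sketch (not part of the test suite): `assert_nlff_less_or_close`
# passes when `params1` achieves a negative log-fit-function value no worse
# than `params0` (up to tolerance).  A fitted normal should beat a
# deliberately poor guess under `nnlf`, the negative log-likelihood.
def _nlff_helper_example():
    rng = np.random.default_rng(12345)
    data = stats.norm.rvs(loc=1.0, scale=2.0, size=500, random_state=rng)
    fitted = stats.norm.fit(data)  # MLE of (loc, scale) minimizes nnlf
    poor_guess = (0.0, 1.0)
    assert_nlff_less_or_close(stats.norm, data, fitted, poor_guess)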


class TestFit:
    dist = stats.binom  # type: ignore[attr-defined]
    seed = 654634816187
    rng = np.random.default_rng(seed)
    data = stats.binom.rvs(5, 0.5, size=100, random_state=rng)  # type: ignore[attr-defined] # noqa
    shape_bounds_a = [(1, 10), (0, 1)]
    shape_bounds_d = {'n': (1, 10), 'p': (0, 1)}
    atol = 5e-2
    rtol = 1e-2
    tols = {'atol': atol, 'rtol': rtol}

    def opt(self, *args, **kwds):
        return differential_evolution(*args, seed=0, **kwds)

    def test_dist_iv(self):
        message = "`dist` must be an instance of..."
        with pytest.raises(ValueError, match=message):
            stats.fit(10, self.data, self.shape_bounds_a)

    def test_data_iv(self):
        message = "`data` must be exactly one-dimensional."
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, [[1, 2, 3]], self.shape_bounds_a)

        message = "All elements of `data` must be finite numbers."
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, [1, 2, 3, np.nan], self.shape_bounds_a)
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, [1, 2, 3, np.inf], self.shape_bounds_a)
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, ['1', '2', '3'], self.shape_bounds_a)

    def test_bounds_iv(self):
        message = "Bounds provided for the following unrecognized..."
        shape_bounds = {'n': (1, 10), 'p': (0, 1), '1': (0, 10)}
        with pytest.warns(RuntimeWarning, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "Each element of a `bounds` sequence must be a tuple..."
        shape_bounds = [(1, 10, 3), (0, 1)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "Each element of `bounds` must be a tuple specifying..."
        shape_bounds = [(1, 10, 3), (0, 1, 0.5)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)
        shape_bounds = [1, 0]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "A `bounds` sequence must contain at least 2 elements..."
        shape_bounds = [(1, 10)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "A `bounds` sequence may not contain more than 3 elements..."
        bounds = [(1, 10), (1, 10), (1, 10), (1, 10)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, bounds)

        message = "There are no values for `p` on the interval..."
        shape_bounds = {'n': (1, 10), 'p': (1, 0)}
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "There are no values for `n` on the interval..."
        shape_bounds = [(10, 1), (0, 1)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "There are no integer values for `n` on the interval..."
        shape_bounds = [(1.4, 1.6), (0, 1)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

        message = "The intersection of user-provided bounds for `n`"
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data)
        shape_bounds = [(-np.inf, np.inf), (0, 1)]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, shape_bounds)

    def test_guess_iv(self):
        message = "Guesses provided for the following unrecognized..."
        guess = {'n': 1, 'p': 0.5, '1': 255}
        with pytest.warns(RuntimeWarning, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "Each element of `guess` must be a scalar..."
        guess = {'n': 1, 'p': 'hi'}
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)
        guess = [1, 'f']
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)
        guess = [[1, 2]]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "A `guess` sequence must contain at least 2..."
        guess = [1]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "A `guess` sequence may not contain more than 3..."
        guess = [1, 2, 3, 4]
        with pytest.raises(ValueError, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "Guess for parameter `n` rounded..."
        guess = {'n': 4.5, 'p': -0.5}
        with pytest.warns(RuntimeWarning, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "Guess for parameter `loc` rounded..."
        guess = [5, 0.5, 0.5]
        with pytest.warns(RuntimeWarning, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "Guess for parameter `p` clipped..."
        guess = {'n': 5, 'p': -0.5}
        with pytest.warns(RuntimeWarning, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

        message = "Guess for parameter `loc` clipped..."
        guess = [5, 0.5, 1]
        with pytest.warns(RuntimeWarning, match=message):
            stats.fit(self.dist, self.data, self.shape_bounds_d, guess=guess)

    def basic_fit_test(self, dist_name, method):
        N = 5000
        dist_data = dict(distcont + distdiscrete)
        rng = np.random.default_rng(self.seed)
        dist = getattr(stats, dist_name)
        shapes = np.array(dist_data[dist_name])
        bounds = np.empty((len(shapes) + 2, 2), dtype=np.float64)
        bounds[:-2, 0] = shapes/10.**np.sign(shapes)
        bounds[:-2, 1] = shapes*10.**np.sign(shapes)
        bounds[-2] = (0, 10)
        bounds[-1] = (1e-16, 10)
        loc = rng.uniform(*bounds[-2])
        scale = rng.uniform(*bounds[-1])
        ref = list(dist_data[dist_name]) + [loc, scale]

        if getattr(dist, 'pmf', False):
            ref = ref[:-1]
            ref[-1] = np.floor(loc)
            data = dist.rvs(*ref, size=N, random_state=rng)
            bounds = bounds[:-1]
        if getattr(dist, 'pdf', False):
            data = dist.rvs(*ref, size=N, random_state=rng)

        with npt.suppress_warnings() as sup:
            sup.filter(RuntimeWarning, "overflow encountered")
            res = stats.fit(dist, data, bounds, method=method,
                            optimizer=self.opt)

        nlff_names = {'mle': 'nnlf', 'mse': '_penalized_nlpsf'}
        nlff_name = nlff_names[method]
        assert_nlff_less_or_close(dist, data, res.params, ref, **self.tols,
                                  nlff_name=nlff_name)

    @pytest.mark.parametrize("dist_name", cases_test_fit_mle())
    def test_basic_fit_mle(self, dist_name):
        self.basic_fit_test(dist_name, "mle")

    @pytest.mark.parametrize("dist_name", cases_test_fit_mse())
    def test_basic_fit_mse(self, dist_name):
        self.basic_fit_test(dist_name, "mse")
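
    # Illustrative sketch (not part of the test suite): the `stats.fit` call
    # pattern that `basic_fit_test` drives, spelled out for one distribution.
    # Bounds cover the shape `a` as well as `loc` and `scale`, and
    # `differential_evolution` (via `self.opt`) serves as the global
    # optimizer.
    def _stats_fit_example(self):
        rng = np.random.default_rng(self.seed)
        data = stats.gamma.rvs(2.0, loc=1.0, scale=3.0, size=2000,
                               random_state=rng)
        bounds = {'a': (0.1, 10), 'loc': (0, 10), 'scale': (0.1, 10)}
        res = stats.fit(stats.gamma, data, bounds, optimizer=self.opt)
        return res.params  # namedtuple of fitted (a, loc, scale)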

    def test_argus(self):
        # Can't guarantee that all distributions will fit all data with
        # arbitrary bounds. This distribution just happens to fail above.
        # Try something slightly different.
        N = 1000
        rng = np.random.default_rng(self.seed)
        dist = stats.argus
        shapes = (1., 2., 3.)
        data = dist.rvs(*shapes, size=N, random_state=rng)
        shape_bounds = {'chi': (0.1, 10), 'loc': (0.1, 10), 'scale': (0.1, 10)}
        res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)

        assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)

    def test_foldnorm(self):
        # Can't guarantee that all distributions will fit all data with
        # arbitrary bounds. This distribution just happens to fail above.
        # Try something slightly different.
        N = 1000
        rng = np.random.default_rng(self.seed)
        dist = stats.foldnorm
        shapes = (1.952125337355587, 2., 3.)
        data = dist.rvs(*shapes, size=N, random_state=rng)
        shape_bounds = {'c': (0.1, 10), 'loc': (0.1, 10), 'scale': (0.1, 10)}
        res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)

        assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)

    def test_truncpareto(self):
        # Can't guarantee that all distributions will fit all data with
        # arbitrary bounds. This distribution just happens to fail above.
        # Try something slightly different.
        N = 1000
        rng = np.random.default_rng(self.seed)
        dist = stats.truncpareto
        shapes = (1.8, 5.3, 2.3, 4.1)
        data = dist.rvs(*shapes, size=N, random_state=rng)
        shape_bounds = [(0.1, 10)]*4
        res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)

        assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)

    def test_truncweibull_min(self):
        # Can't guarantee that all distributions will fit all data with
        # arbitrary bounds. This distribution just happens to fail above.
        # Try something slightly different.
        N = 1000
        rng = np.random.default_rng(self.seed)
        dist = stats.truncweibull_min
        shapes = (2.5, 0.25, 1.75, 2., 3.)
        data = dist.rvs(*shapes, size=N, random_state=rng)
        shape_bounds = [(0.1, 10)]*5
        res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)

        assert_nlff_less_or_close(dist, data, res.params, shapes, **self.tols)

    def test_missing_shape_bounds(self):
        # some distributions have a small domain w.r.t. a parameter, e.g.
        # $p \in [0, 1]$ for binomial distribution
        # User does not need to provide these because the intersection of the
        # user's bounds (none) and the distribution's domain is finite
        N = 1000
        rng = np.random.default_rng(self.seed)

        dist = stats.binom
        n, p, loc = 10, 0.65, 0
        data = dist.rvs(n, p, loc=loc, size=N, random_state=rng)
        shape_bounds = {'n': np.array([0, 20])}  # check arrays are OK, too
        res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)
        assert_allclose(res.params, (n, p, loc), **self.tols)

        dist = stats.bernoulli
        p, loc = 0.314159, 0
        data = dist.rvs(p, loc=loc, size=N, random_state=rng)
        res = stats.fit(dist, data, optimizer=self.opt)
        assert_allclose(res.params, (p, loc), **self.tols)

    def test_fit_only_loc_scale(self):
        # fit only loc
        N = 5000
        rng = np.random.default_rng(self.seed)

        dist = stats.norm
        loc, scale = 1.5, 1
        data = dist.rvs(loc=loc, size=N, random_state=rng)
        loc_bounds = (0, 5)
        bounds = {'loc': loc_bounds}
        res = stats.fit(dist, data, bounds, optimizer=self.opt)
        assert_allclose(res.params, (loc, scale), **self.tols)

        # fit only scale
        loc, scale = 0, 2.5
        data = dist.rvs(scale=scale, size=N, random_state=rng)
        scale_bounds = (0, 5)
        bounds = {'scale': scale_bounds}
        res = stats.fit(dist, data, bounds, optimizer=self.opt)
        assert_allclose(res.params, (loc, scale), **self.tols)

        # fit only loc and scale
        dist = stats.norm
        loc, scale = 1.5, 2.5
        data = dist.rvs(loc=loc, scale=scale, size=N, random_state=rng)
        bounds = {'loc': loc_bounds, 'scale': scale_bounds}
        res = stats.fit(dist, data, bounds, optimizer=self.opt)
        assert_allclose(res.params, (loc, scale), **self.tols)

    def test_everything_fixed(self):
        N = 5000
        rng = np.random.default_rng(self.seed)

        dist = stats.norm
        loc, scale = 1.5, 2.5
        data = dist.rvs(loc=loc, scale=scale, size=N, random_state=rng)

        # loc, scale fixed to 0, 1 by default
        res = stats.fit(dist, data)
        assert_allclose(res.params, (0, 1), **self.tols)

        # loc, scale explicitly fixed
        bounds = {'loc': (loc, loc), 'scale': (scale, scale)}
        res = stats.fit(dist, data, bounds)
        assert_allclose(res.params, (loc, scale), **self.tols)

        # `n` gets fixed during polishing
        dist = stats.binom
        n, p, loc = 10, 0.65, 0
        data = dist.rvs(n, p, loc=loc, size=N, random_state=rng)
        shape_bounds = {'n': (0, 20), 'p': (0.65, 0.65)}
        res = stats.fit(dist, data, shape_bounds, optimizer=self.opt)
        assert_allclose(res.params, (n, p, loc), **self.tols)

    def test_failure(self):
        N = 5000
        rng = np.random.default_rng(self.seed)

        dist = stats.nbinom
        shapes = (5, 0.5)
        data = dist.rvs(*shapes, size=N, random_state=rng)

        assert data.min() == 0
        # With the lower bound on location at 0.5, the likelihood is zero
        bounds = [(0, 30), (0, 1), (0.5, 10)]
        res = stats.fit(dist, data, bounds)

        message = "Optimization converged to parameter values that are"
        assert res.message.startswith(message)
        assert res.success is False

    @pytest.mark.xslow
    def test_guess(self):
        # Test that guess helps DE find the desired solution
        N = 2000
        rng = np.random.default_rng(self.seed)
        dist = stats.nhypergeom
        params = (20, 7, 12, 0)
        bounds = [(2, 200), (0.7, 70), (1.2, 120), (0, 10)]

        data = dist.rvs(*params, size=N, random_state=rng)

        res = stats.fit(dist, data, bounds, optimizer=self.opt)
        assert not np.allclose(res.params, params, **self.tols)

        res = stats.fit(dist, data, bounds, guess=params, optimizer=self.opt)
        assert_allclose(res.params, params, **self.tols)

    def test_mse_accuracy_1(self):
        # Test maximum spacing estimation against example from Wikipedia
        # https://en.wikipedia.org/wiki/Maximum_spacing_estimation#Examples
        data = [2, 4]
        dist = stats.expon
        bounds = {'loc': (0, 0), 'scale': (1e-8, 10)}
        res_mle = stats.fit(dist, data, bounds=bounds, method='mle')
        assert_allclose(res_mle.params.scale, 3, atol=1e-3)
        res_mse = stats.fit(dist, data, bounds=bounds, method='mse')
        assert_allclose(res_mse.params.scale, 3.915, atol=1e-3)
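
    # Arithmetic behind `test_mse_accuracy_1`: with `loc` fixed at 0, the MLE
    # of the exponential scale is the sample mean, (2 + 4)/2 = 3.  Maximum
    # spacing estimation instead maximizes the geometric mean of the CDF
    # spacings, which for this two-point sample yields scale ~3.915, the
    # value quoted in the Wikipedia example.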

    def test_mse_accuracy_2(self):
        # Test maximum spacing estimation against example from Wikipedia
        # https://en.wikipedia.org/wiki/Maximum_spacing_estimation#Examples
        rng = np.random.default_rng(9843212616816518964)

        dist = stats.uniform
        n = 10
        data = dist(3, 6).rvs(size=n, random_state=rng)
        bounds = {'loc': (0, 10), 'scale': (1e-8, 10)}
        res = stats.fit(dist, data, bounds=bounds, method='mse')
        # (loc=3.608118420015416, scale=5.509323262055043)

        x = np.sort(data)
        a = (n*x[0] - x[-1])/(n - 1)
        b = (n*x[-1] - x[0])/(n - 1)
        ref = a, b-a  # (3.6081133632151503, 5.509328130317254)
        assert_allclose(res.params, ref, rtol=1e-4)


# Data from Matlab: https://www.mathworks.com/help/stats/lillietest.html
examgrades = [65, 61, 81, 88, 69, 89, 55, 84, 86, 84, 71, 81, 84, 81, 78, 67,
              96, 66, 73, 75, 59, 71, 69, 63, 79, 76, 63, 85, 87, 88, 80, 71,
              65, 84, 71, 75, 81, 79, 64, 65, 84, 77, 70, 75, 84, 75, 73, 92,
              90, 79, 80, 71, 73, 71, 58, 79, 73, 64, 77, 82, 81, 59, 54, 82,
              57, 79, 79, 73, 74, 82, 63, 64, 73, 69, 87, 68, 81, 73, 83, 73,
              80, 73, 73, 71, 66, 78, 64, 74, 68, 67, 75, 75, 80, 85, 74, 76,
              80, 77, 93, 70, 86, 80, 81, 83, 68, 60, 85, 64, 74, 82, 81, 77,
              66, 85, 75, 81, 69, 60, 83, 72]
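

# Illustrative sketch (not part of the test suite): the basic
# `goodness_of_fit` call pattern used throughout the class below.  With all
# parameters supplied via `known_params`, nothing is fit, and the Monte Carlo
# null distribution is drawn from the fully specified distribution.
def _gof_example():
    rng = np.random.default_rng(12345)
    x = stats.norm.rvs(size=75, random_state=rng)
    res = goodness_of_fit(stats.norm, x,
                          known_params={'loc': 0, 'scale': 1},
                          statistic='ks', random_state=rng)
    return res.statistic, res.pvalue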


class TestGoodnessOfFit:

    def test_gof_iv(self):
        dist = stats.norm
        x = [1, 2, 3]

        message = r"`dist` must be a \(non-frozen\) instance of..."
        with pytest.raises(TypeError, match=message):
            goodness_of_fit(stats.norm(), x)

        message = "`data` must be a one-dimensional array of numbers."
        with pytest.raises(ValueError, match=message):
            goodness_of_fit(dist, [[1, 2, 3]])

        message = "`statistic` must be one of..."
        with pytest.raises(ValueError, match=message):
            goodness_of_fit(dist, x, statistic='mm')

        message = "`n_mc_samples` must be an integer."
        with pytest.raises(TypeError, match=message):
            goodness_of_fit(dist, x, n_mc_samples=1000.5)

        message = "'herring' cannot be used to seed a"
        with pytest.raises(ValueError, match=message):
            goodness_of_fit(dist, x, random_state='herring')

    def test_against_ks(self):
        rng = np.random.default_rng(8517426291317196949)
        x = examgrades
        known_params = {'loc': np.mean(x), 'scale': np.std(x, ddof=1)}
        res = goodness_of_fit(stats.norm, x, known_params=known_params,
                              statistic='ks', random_state=rng)
        ref = stats.kstest(x, stats.norm(**known_params).cdf, method='exact')
        assert_allclose(res.statistic, ref.statistic)  # ~0.0848
        assert_allclose(res.pvalue, ref.pvalue, atol=5e-3)  # ~0.335

    def test_against_lilliefors(self):
        rng = np.random.default_rng(2291803665717442724)
        x = examgrades
        res = goodness_of_fit(stats.norm, x, statistic='ks', random_state=rng)
        known_params = {'loc': np.mean(x), 'scale': np.std(x, ddof=1)}
        ref = stats.kstest(x, stats.norm(**known_params).cdf, method='exact')
        assert_allclose(res.statistic, ref.statistic)  # ~0.0848
        assert_allclose(res.pvalue, 0.0348, atol=5e-3)

    def test_against_cvm(self):
        rng = np.random.default_rng(8674330857509546614)
        x = examgrades
        known_params = {'loc': np.mean(x), 'scale': np.std(x, ddof=1)}
        res = goodness_of_fit(stats.norm, x, known_params=known_params,
                              statistic='cvm', random_state=rng)
        ref = stats.cramervonmises(x, stats.norm(**known_params).cdf)
        assert_allclose(res.statistic, ref.statistic)  # ~0.090
        assert_allclose(res.pvalue, ref.pvalue, atol=5e-3)  # ~0.636

    def test_against_anderson_case_0(self):
        # "Case 0" is where loc and scale are known [1]
        rng = np.random.default_rng(7384539336846690410)
        x = np.arange(1, 101)
        # loc that produced critical value of statistic found w/ root_scalar
        known_params = {'loc': 45.01575354024957, 'scale': 30}
        res = goodness_of_fit(stats.norm, x, known_params=known_params,
                              statistic='ad', random_state=rng)
        assert_allclose(res.statistic, 2.492)  # See [1] Table 1A 1.0
        assert_allclose(res.pvalue, 0.05, atol=5e-3)

    def test_against_anderson_case_1(self):
        # "Case 1" is where scale is known and loc is fit [1]
        rng = np.random.default_rng(5040212485680146248)
        x = np.arange(1, 101)
        # scale that produced critical value of statistic found w/ root_scalar
        known_params = {'scale': 29.957112639101933}
        res = goodness_of_fit(stats.norm, x, known_params=known_params,
                              statistic='ad', random_state=rng)
        assert_allclose(res.statistic, 0.908)  # See [1] Table 1B 1.1
        assert_allclose(res.pvalue, 0.1, atol=5e-3)

    def test_against_anderson_case_2(self):
        # "Case 2" is where loc is known and scale is fit [1]
        rng = np.random.default_rng(726693985720914083)
        x = np.arange(1, 101)
        # loc that produced critical value of statistic found w/ root_scalar
        known_params = {'loc': 44.5680212261933}
        res = goodness_of_fit(stats.norm, x, known_params=known_params,
                              statistic='ad', random_state=rng)
        assert_allclose(res.statistic, 2.904)  # See [1] Table 1B 1.2
        assert_allclose(res.pvalue, 0.025, atol=5e-3)

    def test_against_anderson_case_3(self):
        # "Case 3" is where both loc and scale are fit [1]
        rng = np.random.default_rng(6763691329830218206)
        # c that produced critical value of statistic found w/ root_scalar
        x = stats.skewnorm.rvs(1.4477847789132101, loc=1, scale=2, size=100,
                               random_state=rng)
        res = goodness_of_fit(stats.norm, x, statistic='ad', random_state=rng)
        assert_allclose(res.statistic, 0.559)  # See [1] Table 1B 1.2
        assert_allclose(res.pvalue, 0.15, atol=5e-3)

    @pytest.mark.slow
    def test_against_anderson_gumbel_r(self):
        rng = np.random.default_rng(7302761058217743)
        # c that produced critical value of statistic found w/ root_scalar
        x = stats.genextreme(0.051896837188595134, loc=0.5,
                             scale=1.5).rvs(size=1000, random_state=rng)
        res = goodness_of_fit(stats.gumbel_r, x, statistic='ad',
                              random_state=rng)
        ref = stats.anderson(x, dist='gumbel_r')
        assert_allclose(res.statistic, ref.critical_values[0])
        assert_allclose(res.pvalue, ref.significance_level[0]/100, atol=5e-3)

    def test_params_effects(self):
        # Ensure that `guessed_params`, `fit_params`, and `known_params` have
        # the intended effects.
        rng = np.random.default_rng(9121950977643805391)
        x = stats.skewnorm.rvs(-5.044559778383153, loc=1, scale=2, size=50,
                               random_state=rng)

        # Show that a parameter in `guessed_params` is only a starting point:
        # the fitted value is free to move away from the guess.  Parameters
        # in `fit_params` and `known_params` are fixed to the provided values.
        guessed_params = {'c': 13.4}
        fit_params = {'scale': 13.73}
        known_params = {'loc': -13.85}
        rng = np.random.default_rng(9121950977643805391)
        res1 = goodness_of_fit(stats.weibull_min, x, n_mc_samples=2,
                               guessed_params=guessed_params,
                               fit_params=fit_params,
                               known_params=known_params, random_state=rng)
        assert not np.allclose(res1.fit_result.params.c, 13.4)
        assert_equal(res1.fit_result.params.scale, 13.73)
        assert_equal(res1.fit_result.params.loc, -13.85)

        # Show that changing the guess changes the parameter that gets fit,
        # and it changes the null distribution
        guessed_params = {'c': 2}
        rng = np.random.default_rng(9121950977643805391)
        res2 = goodness_of_fit(stats.weibull_min, x, n_mc_samples=2,
                               guessed_params=guessed_params,
                               fit_params=fit_params,
                               known_params=known_params, random_state=rng)
        assert not np.allclose(res2.fit_result.params.c,
                               res1.fit_result.params.c, rtol=1e-8)
        assert not np.allclose(res2.null_distribution,
                               res1.null_distribution, rtol=1e-8)
        assert_equal(res2.fit_result.params.scale, 13.73)
        assert_equal(res2.fit_result.params.loc, -13.85)

        # If we set all parameters as fit_params and known_params,
        # they're all fixed to those values, but the null distribution
        # varies.
        fit_params = {'c': 13.4, 'scale': 13.73}
        rng = np.random.default_rng(9121950977643805391)
        res3 = goodness_of_fit(stats.weibull_min, x, n_mc_samples=2,
                               guessed_params=guessed_params,
                               fit_params=fit_params,
                               known_params=known_params, random_state=rng)
        assert_equal(res3.fit_result.params.c, 13.4)
        assert_equal(res3.fit_result.params.scale, 13.73)
        assert_equal(res3.fit_result.params.loc, -13.85)
        assert not np.allclose(res3.null_distribution, res1.null_distribution)


class TestFitResult:
    def test_plot_iv(self):
        rng = np.random.default_rng(1769658657308472721)
        data = stats.norm.rvs(0, 1, size=100, random_state=rng)

        def optimizer(*args, **kwargs):
            return differential_evolution(*args, **kwargs, seed=rng)

        bounds = [(0, 30), (0, 1)]
        res = stats.fit(stats.norm, data, bounds, optimizer=optimizer)
        try:
            import matplotlib  # noqa
            message = r"`plot_type` must be one of \{'..."
            with pytest.raises(ValueError, match=message):
                res.plot(plot_type='llama')
        except (ModuleNotFoundError, ImportError):
            message = r"matplotlib must be installed to use method `plot`."
            with pytest.raises(ModuleNotFoundError, match=message):
                res.plot(plot_type='llama')