vq.py 29 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795
"""
K-means clustering and vector quantization (:mod:`scipy.cluster.vq`)
====================================================================

Provides routines for k-means clustering, generating code books
from k-means models and quantizing vectors by comparing them with
centroids in a code book.

.. autosummary::
   :toctree: generated/

   whiten -- Normalize a group of observations so each feature has unit variance
   vq -- Calculate code book membership of a set of observation vectors
   kmeans -- Perform k-means on a set of observation vectors forming k clusters
   kmeans2 -- A different implementation of k-means with more methods
           -- for initializing centroids

Background information
----------------------
The k-means algorithm takes as input the number of clusters to
generate, k, and a set of observation vectors to cluster. It
returns a set of centroids, one for each of the k clusters. An
observation vector is classified with the cluster number or
centroid index of the centroid closest to it.

A vector v belongs to cluster i if it is closer to centroid i than
any other centroid. If v belongs to i, we say centroid i is the
dominating centroid of v. The k-means algorithm tries to
minimize distortion, which is defined as the sum of the squared distances
between each observation vector and its dominating centroid.
The minimization is achieved by iteratively reclassifying
the observations into clusters and recalculating the centroids until
a configuration is reached in which the centroids are stable. One can
also define a maximum number of iterations.

Since vector quantization is a natural application for k-means,
information theory terminology is often used. The centroid index
or cluster index is also referred to as a "code" and the table
mapping codes to centroids and vice versa is often referred to as a
"code book". The result of k-means, a set of centroids, can be
used to quantize vectors. Quantization aims to find an encoding of
vectors that reduces the expected distortion.

All routines expect obs to be an M by N array, where the rows are
the observation vectors (`kmeans2` additionally accepts a length M
array of 1-D observations). The codebook is a k by N array, where the
ith row is the centroid of code word i. The observation vectors
and centroids have the same feature dimension.

As an example, suppose we wish to compress a 24-bit color image
(each pixel is represented by one byte for red, one for green, and
one for blue) before sending it over the web. By using a smaller
8-bit encoding, we can reduce the amount of data by two
thirds. Ideally, the colors for each of the 256 possible 8-bit
encoding values should be chosen to minimize distortion of the
color. Running k-means with k=256 generates a code book of 256
codes, which fills up all possible 8-bit sequences. Instead of
sending a 3-byte value for each pixel, the 8-bit centroid index
(or code word) of the dominating centroid is transmitted. The code
book is also sent over the wire so each 8-bit code can be
translated back to a 24-bit pixel value representation. If the
image of interest were of an ocean, we would expect many 24-bit
blues to be represented by 8-bit codes. If it were an image of a
human face, more flesh-tone colors would be represented in the
code book.
"""
  58. import warnings
  59. import numpy as np
  60. from collections import deque
  61. from scipy._lib._util import _asarray_validated, check_random_state,\
  62. rng_integers
  63. from scipy.spatial.distance import cdist
  64. from . import _vq
  65. __docformat__ = 'restructuredtext'
  66. __all__ = ['whiten', 'vq', 'kmeans', 'kmeans2']
  67. class ClusterError(Exception):
  68. pass
  69. def whiten(obs, check_finite=True):
  70. """
  71. Normalize a group of observations on a per feature basis.
  72. Before running k-means, it is beneficial to rescale each feature
  73. dimension of the observation set by its standard deviation (i.e. "whiten"
  74. it - as in "white noise" where each frequency has equal power).
  75. Each feature is divided by its standard deviation across all observations
  76. to give it unit variance.
  77. Parameters
  78. ----------
  79. obs : ndarray
  80. Each row of the array is an observation. The
  81. columns are the features seen during each observation.
  82. >>> # f0 f1 f2
  83. >>> obs = [[ 1., 1., 1.], #o0
  84. ... [ 2., 2., 2.], #o1
  85. ... [ 3., 3., 3.], #o2
  86. ... [ 4., 4., 4.]] #o3
  87. check_finite : bool, optional
  88. Whether to check that the input matrices contain only finite numbers.
  89. Disabling may give a performance gain, but may result in problems
  90. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  91. Default: True
  92. Returns
  93. -------
  94. result : ndarray
  95. Contains the values in `obs` scaled by the standard deviation
  96. of each column.
  97. Examples
  98. --------
  99. >>> import numpy as np
  100. >>> from scipy.cluster.vq import whiten
  101. >>> features = np.array([[1.9, 2.3, 1.7],
  102. ... [1.5, 2.5, 2.2],
  103. ... [0.8, 0.6, 1.7,]])
  104. >>> whiten(features)
  105. array([[ 4.17944278, 2.69811351, 7.21248917],
  106. [ 3.29956009, 2.93273208, 9.33380951],
  107. [ 1.75976538, 0.7038557 , 7.21248917]])
  108. """
  109. obs = _asarray_validated(obs, check_finite=check_finite)
  110. std_dev = obs.std(axis=0)
  111. zero_std_mask = std_dev == 0
  112. if zero_std_mask.any():
  113. std_dev[zero_std_mask] = 1.0
  114. warnings.warn("Some columns have standard deviation zero. "
  115. "The values of these columns will not change.",
  116. RuntimeWarning)
  117. return obs / std_dev
  118. def vq(obs, code_book, check_finite=True):
  119. """
  120. Assign codes from a code book to observations.
  121. Assigns a code from a code book to each observation. Each
  122. observation vector in the 'M' by 'N' `obs` array is compared with the
  123. centroids in the code book and assigned the code of the closest
  124. centroid.
  125. The features in `obs` should have unit variance, which can be
  126. achieved by passing them through the whiten function. The code
  127. book can be created with the k-means algorithm or a different
  128. encoding algorithm.
  129. Parameters
  130. ----------
  131. obs : ndarray
  132. Each row of the 'M' x 'N' array is an observation. The columns are
  133. the "features" seen during each observation. The features must be
  134. whitened first using the whiten function or something equivalent.
  135. code_book : ndarray
  136. The code book is usually generated using the k-means algorithm.
  137. Each row of the array holds a different code, and the columns are
  138. the features of the code.
  139. >>> # f0 f1 f2 f3
  140. >>> code_book = [
  141. ... [ 1., 2., 3., 4.], #c0
  142. ... [ 1., 2., 3., 4.], #c1
  143. ... [ 1., 2., 3., 4.]] #c2
  144. check_finite : bool, optional
  145. Whether to check that the input matrices contain only finite numbers.
  146. Disabling may give a performance gain, but may result in problems
  147. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  148. Default: True
  149. Returns
  150. -------
  151. code : ndarray
  152. A length M array holding the code book index for each observation.
  153. dist : ndarray
  154. The distortion (distance) between the observation and its nearest
  155. code.
  156. Examples
  157. --------
  158. >>> import numpy as np
  159. >>> from scipy.cluster.vq import vq
  160. >>> code_book = np.array([[1.,1.,1.],
  161. ... [2.,2.,2.]])
  162. >>> features = np.array([[ 1.9,2.3,1.7],
  163. ... [ 1.5,2.5,2.2],
  164. ... [ 0.8,0.6,1.7]])
  165. >>> vq(features,code_book)
  166. (array([1, 1, 0],'i'), array([ 0.43588989, 0.73484692, 0.83066239]))
  167. """
  168. obs = _asarray_validated(obs, check_finite=check_finite)
  169. code_book = _asarray_validated(code_book, check_finite=check_finite)
  170. ct = np.common_type(obs, code_book)
  171. c_obs = obs.astype(ct, copy=False)
  172. c_code_book = code_book.astype(ct, copy=False)
  173. if np.issubdtype(ct, np.float64) or np.issubdtype(ct, np.float32):
  174. return _vq.vq(c_obs, c_code_book)
  175. return py_vq(obs, code_book, check_finite=False)
  176. def py_vq(obs, code_book, check_finite=True):
  177. """ Python version of vq algorithm.
  178. The algorithm computes the Euclidean distance between each
  179. observation and every frame in the code_book.
  180. Parameters
  181. ----------
  182. obs : ndarray
  183. Expects a rank 2 array. Each row is one observation.
  184. code_book : ndarray
  185. Code book to use. Same format than obs. Should have same number of
  186. features (e.g., columns) than obs.
  187. check_finite : bool, optional
  188. Whether to check that the input matrices contain only finite numbers.
  189. Disabling may give a performance gain, but may result in problems
  190. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  191. Default: True
  192. Returns
  193. -------
  194. code : ndarray
  195. code[i] gives the label of the ith obversation; its code is
  196. code_book[code[i]].
  197. mind_dist : ndarray
  198. min_dist[i] gives the distance between the ith observation and its
  199. corresponding code.
  200. Notes
  201. -----
  202. This function is slower than the C version but works for
  203. all input types. If the inputs have the wrong types for the
  204. C versions of the function, this one is called as a last resort.
  205. It is about 20 times slower than the C version.
  206. """
  207. obs = _asarray_validated(obs, check_finite=check_finite)
  208. code_book = _asarray_validated(code_book, check_finite=check_finite)
  209. if obs.ndim != code_book.ndim:
  210. raise ValueError("Observation and code_book should have the same rank")
  211. if obs.ndim == 1:
  212. obs = obs[:, np.newaxis]
  213. code_book = code_book[:, np.newaxis]
  214. dist = cdist(obs, code_book)
  215. code = dist.argmin(axis=1)
  216. min_dist = dist[np.arange(len(code)), code]
  217. return code, min_dist
  218. def _kmeans(obs, guess, thresh=1e-5):
  219. """ "raw" version of k-means.
  220. Returns
  221. -------
  222. code_book
  223. The lowest distortion codebook found.
  224. avg_dist
  225. The average distance a observation is from a code in the book.
  226. Lower means the code_book matches the data better.
  227. See Also
  228. --------
  229. kmeans : wrapper around k-means
  230. Examples
  231. --------
  232. Note: not whitened in this example.
  233. >>> import numpy as np
  234. >>> from scipy.cluster.vq import _kmeans
  235. >>> features = np.array([[ 1.9,2.3],
  236. ... [ 1.5,2.5],
  237. ... [ 0.8,0.6],
  238. ... [ 0.4,1.8],
  239. ... [ 1.0,1.0]])
  240. >>> book = np.array((features[0],features[2]))
  241. >>> _kmeans(features,book)
  242. (array([[ 1.7 , 2.4 ],
  243. [ 0.73333333, 1.13333333]]), 0.40563916697728591)
  244. """
  245. code_book = np.asarray(guess)
  246. diff = np.inf
  247. prev_avg_dists = deque([diff], maxlen=2)
  248. while diff > thresh:
  249. # compute membership and distances between obs and code_book
  250. obs_code, distort = vq(obs, code_book, check_finite=False)
  251. prev_avg_dists.append(distort.mean(axis=-1))
  252. # recalc code_book as centroids of associated obs
  253. code_book, has_members = _vq.update_cluster_means(obs, obs_code,
  254. code_book.shape[0])
  255. code_book = code_book[has_members]
  256. diff = prev_avg_dists[0] - prev_avg_dists[1]
  257. return code_book, prev_avg_dists[1]
  258. def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True,
  259. *, seed=None):
  260. """
  261. Performs k-means on a set of observation vectors forming k clusters.
  262. The k-means algorithm adjusts the classification of the observations
  263. into clusters and updates the cluster centroids until the position of
  264. the centroids is stable over successive iterations. In this
  265. implementation of the algorithm, the stability of the centroids is
  266. determined by comparing the absolute value of the change in the average
  267. Euclidean distance between the observations and their corresponding
  268. centroids against a threshold. This yields
  269. a code book mapping centroids to codes and vice versa.
  270. Parameters
  271. ----------
  272. obs : ndarray
  273. Each row of the M by N array is an observation vector. The
  274. columns are the features seen during each observation.
  275. The features must be whitened first with the `whiten` function.
  276. k_or_guess : int or ndarray
  277. The number of centroids to generate. A code is assigned to
  278. each centroid, which is also the row index of the centroid
  279. in the code_book matrix generated.
  280. The initial k centroids are chosen by randomly selecting
  281. observations from the observation matrix. Alternatively,
  282. passing a k by N array specifies the initial k centroids.
  283. iter : int, optional
  284. The number of times to run k-means, returning the codebook
  285. with the lowest distortion. This argument is ignored if
  286. initial centroids are specified with an array for the
  287. ``k_or_guess`` parameter. This parameter does not represent the
  288. number of iterations of the k-means algorithm.
  289. thresh : float, optional
  290. Terminates the k-means algorithm if the change in
  291. distortion since the last k-means iteration is less than
  292. or equal to threshold.
  293. check_finite : bool, optional
  294. Whether to check that the input matrices contain only finite numbers.
  295. Disabling may give a performance gain, but may result in problems
  296. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  297. Default: True
  298. seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
  299. Seed for initializing the pseudo-random number generator.
  300. If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
  301. singleton is used.
  302. If `seed` is an int, a new ``RandomState`` instance is used,
  303. seeded with `seed`.
  304. If `seed` is already a ``Generator`` or ``RandomState`` instance then
  305. that instance is used.
  306. The default is None.
  307. Returns
  308. -------
  309. codebook : ndarray
  310. A k by N array of k centroids. The ith centroid
  311. codebook[i] is represented with the code i. The centroids
  312. and codes generated represent the lowest distortion seen,
  313. not necessarily the globally minimal distortion.
  314. Note that the number of centroids is not necessarily the same as the
  315. ``k_or_guess`` parameter, because centroids assigned to no observations
  316. are removed during iterations.
  317. distortion : float
  318. The mean (non-squared) Euclidean distance between the observations
  319. passed and the centroids generated. Note the difference to the standard
  320. definition of distortion in the context of the k-means algorithm, which
  321. is the sum of the squared distances.
  322. See Also
  323. --------
  324. kmeans2 : a different implementation of k-means clustering
  325. with more methods for generating initial centroids but without
  326. using a distortion change threshold as a stopping criterion.
  327. whiten : must be called prior to passing an observation matrix
  328. to kmeans.
  329. Notes
  330. -----
  331. For more functionalities or optimal performance, you can use
  332. `sklearn.cluster.KMeans <https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html>`_.
  333. `This <https://hdbscan.readthedocs.io/en/latest/performance_and_scalability.html#comparison-of-high-performance-implementations>`_
  334. is a benchmark result of several implementations.
  335. Examples
  336. --------
  337. >>> import numpy as np
  338. >>> from scipy.cluster.vq import vq, kmeans, whiten
  339. >>> import matplotlib.pyplot as plt
  340. >>> features = np.array([[ 1.9,2.3],
  341. ... [ 1.5,2.5],
  342. ... [ 0.8,0.6],
  343. ... [ 0.4,1.8],
  344. ... [ 0.1,0.1],
  345. ... [ 0.2,1.8],
  346. ... [ 2.0,0.5],
  347. ... [ 0.3,1.5],
  348. ... [ 1.0,1.0]])
  349. >>> whitened = whiten(features)
  350. >>> book = np.array((whitened[0],whitened[2]))
  351. >>> kmeans(whitened,book)
  352. (array([[ 2.3110306 , 2.86287398], # random
  353. [ 0.93218041, 1.24398691]]), 0.85684700941625547)
  354. >>> codes = 3
  355. >>> kmeans(whitened,codes)
  356. (array([[ 2.3110306 , 2.86287398], # random
  357. [ 1.32544402, 0.65607529],
  358. [ 0.40782893, 2.02786907]]), 0.5196582527686241)
  359. >>> # Create 50 datapoints in two clusters a and b
  360. >>> pts = 50
  361. >>> rng = np.random.default_rng()
  362. >>> a = rng.multivariate_normal([0, 0], [[4, 1], [1, 4]], size=pts)
  363. >>> b = rng.multivariate_normal([30, 10],
  364. ... [[10, 2], [2, 1]],
  365. ... size=pts)
  366. >>> features = np.concatenate((a, b))
  367. >>> # Whiten data
  368. >>> whitened = whiten(features)
  369. >>> # Find 2 clusters in the data
  370. >>> codebook, distortion = kmeans(whitened, 2)
  371. >>> # Plot whitened data and cluster centers in red
  372. >>> plt.scatter(whitened[:, 0], whitened[:, 1])
  373. >>> plt.scatter(codebook[:, 0], codebook[:, 1], c='r')
  374. >>> plt.show()
  375. """
  376. obs = _asarray_validated(obs, check_finite=check_finite)
  377. if iter < 1:
  378. raise ValueError("iter must be at least 1, got %s" % iter)
  379. # Determine whether a count (scalar) or an initial guess (array) was passed.
  380. if not np.isscalar(k_or_guess):
  381. guess = _asarray_validated(k_or_guess, check_finite=check_finite)
  382. if guess.size < 1:
  383. raise ValueError("Asked for 0 clusters. Initial book was %s" %
  384. guess)
  385. return _kmeans(obs, guess, thresh=thresh)
  386. # k_or_guess is a scalar, now verify that it's an integer
  387. k = int(k_or_guess)
  388. if k != k_or_guess:
  389. raise ValueError("If k_or_guess is a scalar, it must be an integer.")
  390. if k < 1:
  391. raise ValueError("Asked for %d clusters." % k)
  392. rng = check_random_state(seed)
  393. # initialize best distance value to a large value
  394. best_dist = np.inf
  395. for i in range(iter):
  396. # the initial code book is randomly selected from observations
  397. guess = _kpoints(obs, k, rng)
  398. book, dist = _kmeans(obs, guess, thresh=thresh)
  399. if dist < best_dist:
  400. best_book = book
  401. best_dist = dist
  402. return best_book, best_dist
  403. def _kpoints(data, k, rng):
  404. """Pick k points at random in data (one row = one observation).
  405. Parameters
  406. ----------
  407. data : ndarray
  408. Expect a rank 1 or 2 array. Rank 1 are assumed to describe one
  409. dimensional data, rank 2 multidimensional data, in which case one
  410. row is one observation.
  411. k : int
  412. Number of samples to generate.
  413. rng : `numpy.random.Generator` or `numpy.random.RandomState`
  414. Random number generator.
  415. Returns
  416. -------
  417. x : ndarray
  418. A 'k' by 'N' containing the initial centroids
  419. """
  420. idx = rng.choice(data.shape[0], size=k, replace=False)
  421. return data[idx]
  422. def _krandinit(data, k, rng):
  423. """Returns k samples of a random variable whose parameters depend on data.
  424. More precisely, it returns k observations sampled from a Gaussian random
  425. variable whose mean and covariances are the ones estimated from the data.
  426. Parameters
  427. ----------
  428. data : ndarray
  429. Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
  430. data, rank 2 multidimensional data, in which case one
  431. row is one observation.
  432. k : int
  433. Number of samples to generate.
  434. rng : `numpy.random.Generator` or `numpy.random.RandomState`
  435. Random number generator.
  436. Returns
  437. -------
  438. x : ndarray
  439. A 'k' by 'N' containing the initial centroids
  440. """
  441. mu = data.mean(axis=0)
  442. if data.ndim == 1:
  443. cov = np.cov(data)
  444. x = rng.standard_normal(size=k)
  445. x *= np.sqrt(cov)
  446. elif data.shape[1] > data.shape[0]:
  447. # initialize when the covariance matrix is rank deficient
  448. _, s, vh = np.linalg.svd(data - mu, full_matrices=False)
  449. x = rng.standard_normal(size=(k, s.size))
  450. sVh = s[:, None] * vh / np.sqrt(data.shape[0] - 1)
  451. x = x.dot(sVh)
  452. else:
  453. cov = np.atleast_2d(np.cov(data, rowvar=False))
  454. # k rows, d cols (one row = one obs)
  455. # Generate k sample of a random variable ~ Gaussian(mu, cov)
  456. x = rng.standard_normal(size=(k, mu.size))
  457. x = x.dot(np.linalg.cholesky(cov).T)
  458. x += mu
  459. return x
  460. def _kpp(data, k, rng):
  461. """ Picks k points in the data based on the kmeans++ method.
  462. Parameters
  463. ----------
  464. data : ndarray
  465. Expect a rank 1 or 2 array. Rank 1 is assumed to describe 1-D
  466. data, rank 2 multidimensional data, in which case one
  467. row is one observation.
  468. k : int
  469. Number of samples to generate.
  470. rng : `numpy.random.Generator` or `numpy.random.RandomState`
  471. Random number generator.
  472. Returns
  473. -------
  474. init : ndarray
  475. A 'k' by 'N' containing the initial centroids.
  476. References
  477. ----------
  478. .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
  479. careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
  480. on Discrete Algorithms, 2007.
  481. """
  482. dims = data.shape[1] if len(data.shape) > 1 else 1
  483. init = np.ndarray((k, dims))
  484. for i in range(k):
  485. if i == 0:
  486. init[i, :] = data[rng_integers(rng, data.shape[0])]
  487. else:
  488. D2 = cdist(init[:i,:], data, metric='sqeuclidean').min(axis=0)
  489. probs = D2/D2.sum()
  490. cumprobs = probs.cumsum()
  491. r = rng.uniform()
  492. init[i, :] = data[np.searchsorted(cumprobs, r)]
  493. return init
  494. _valid_init_meth = {'random': _krandinit, 'points': _kpoints, '++': _kpp}
  495. def _missing_warn():
  496. """Print a warning when called."""
  497. warnings.warn("One of the clusters is empty. "
  498. "Re-run kmeans with a different initialization.")
  499. def _missing_raise():
  500. """Raise a ClusterError when called."""
  501. raise ClusterError("One of the clusters is empty. "
  502. "Re-run kmeans with a different initialization.")
  503. _valid_miss_meth = {'warn': _missing_warn, 'raise': _missing_raise}
  504. def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
  505. missing='warn', check_finite=True, *, seed=None):
  506. """
  507. Classify a set of observations into k clusters using the k-means algorithm.
  508. The algorithm attempts to minimize the Euclidean distance between
  509. observations and centroids. Several initialization methods are
  510. included.
  511. Parameters
  512. ----------
  513. data : ndarray
  514. A 'M' by 'N' array of 'M' observations in 'N' dimensions or a length
  515. 'M' array of 'M' 1-D observations.
  516. k : int or ndarray
  517. The number of clusters to form as well as the number of
  518. centroids to generate. If `minit` initialization string is
  519. 'matrix', or if a ndarray is given instead, it is
  520. interpreted as initial cluster to use instead.
  521. iter : int, optional
  522. Number of iterations of the k-means algorithm to run. Note
  523. that this differs in meaning from the iters parameter to
  524. the kmeans function.
  525. thresh : float, optional
  526. (not used yet)
  527. minit : str, optional
  528. Method for initialization. Available methods are 'random',
  529. 'points', '++' and 'matrix':
  530. 'random': generate k centroids from a Gaussian with mean and
  531. variance estimated from the data.
  532. 'points': choose k observations (rows) at random from data for
  533. the initial centroids.
  534. '++': choose k observations accordingly to the kmeans++ method
  535. (careful seeding)
  536. 'matrix': interpret the k parameter as a k by M (or length k
  537. array for 1-D data) array of initial centroids.
  538. missing : str, optional
  539. Method to deal with empty clusters. Available methods are
  540. 'warn' and 'raise':
  541. 'warn': give a warning and continue.
  542. 'raise': raise an ClusterError and terminate the algorithm.
  543. check_finite : bool, optional
  544. Whether to check that the input matrices contain only finite numbers.
  545. Disabling may give a performance gain, but may result in problems
  546. (crashes, non-termination) if the inputs do contain infinities or NaNs.
  547. Default: True
  548. seed : {None, int, `numpy.random.Generator`, `numpy.random.RandomState`}, optional
  549. Seed for initializing the pseudo-random number generator.
  550. If `seed` is None (or `numpy.random`), the `numpy.random.RandomState`
  551. singleton is used.
  552. If `seed` is an int, a new ``RandomState`` instance is used,
  553. seeded with `seed`.
  554. If `seed` is already a ``Generator`` or ``RandomState`` instance then
  555. that instance is used.
  556. The default is None.
  557. Returns
  558. -------
  559. centroid : ndarray
  560. A 'k' by 'N' array of centroids found at the last iteration of
  561. k-means.
  562. label : ndarray
  563. label[i] is the code or index of the centroid the
  564. ith observation is closest to.
  565. See Also
  566. --------
  567. kmeans
  568. References
  569. ----------
  570. .. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
  571. careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
  572. on Discrete Algorithms, 2007.
  573. Examples
  574. --------
  575. >>> from scipy.cluster.vq import kmeans2
  576. >>> import matplotlib.pyplot as plt
  577. >>> import numpy as np
  578. Create z, an array with shape (100, 2) containing a mixture of samples
  579. from three multivariate normal distributions.
  580. >>> rng = np.random.default_rng()
  581. >>> a = rng.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
  582. >>> b = rng.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
  583. >>> c = rng.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
  584. >>> z = np.concatenate((a, b, c))
  585. >>> rng.shuffle(z)
  586. Compute three clusters.
  587. >>> centroid, label = kmeans2(z, 3, minit='points')
  588. >>> centroid
  589. array([[ 2.22274463, -0.61666946], # may vary
  590. [ 0.54069047, 5.86541444],
  591. [ 6.73846769, 4.01991898]])
  592. How many points are in each cluster?
  593. >>> counts = np.bincount(label)
  594. >>> counts
  595. array([29, 51, 20]) # may vary
  596. Plot the clusters.
  597. >>> w0 = z[label == 0]
  598. >>> w1 = z[label == 1]
  599. >>> w2 = z[label == 2]
  600. >>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
  601. >>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
  602. >>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
  603. >>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
  604. >>> plt.axis('equal')
  605. >>> plt.legend(shadow=True)
  606. >>> plt.show()
  607. """
  608. if int(iter) < 1:
  609. raise ValueError("Invalid iter (%s), "
  610. "must be a positive integer." % iter)
  611. try:
  612. miss_meth = _valid_miss_meth[missing]
  613. except KeyError as e:
  614. raise ValueError("Unknown missing method %r" % (missing,)) from e
  615. data = _asarray_validated(data, check_finite=check_finite)
  616. if data.ndim == 1:
  617. d = 1
  618. elif data.ndim == 2:
  619. d = data.shape[1]
  620. else:
  621. raise ValueError("Input of rank > 2 is not supported.")
  622. if data.size < 1:
  623. raise ValueError("Empty input is not supported.")
  624. # If k is not a single value, it should be compatible with data's shape
  625. if minit == 'matrix' or not np.isscalar(k):
  626. code_book = np.array(k, copy=True)
  627. if data.ndim != code_book.ndim:
  628. raise ValueError("k array doesn't match data rank")
  629. nc = len(code_book)
  630. if data.ndim > 1 and code_book.shape[1] != d:
  631. raise ValueError("k array doesn't match data dimension")
  632. else:
  633. nc = int(k)
  634. if nc < 1:
  635. raise ValueError("Cannot ask kmeans2 for %d clusters"
  636. " (k was %s)" % (nc, k))
  637. elif nc != k:
  638. warnings.warn("k was not an integer, was converted.")
  639. try:
  640. init_meth = _valid_init_meth[minit]
  641. except KeyError as e:
  642. raise ValueError("Unknown init method %r" % (minit,)) from e
  643. else:
  644. rng = check_random_state(seed)
  645. code_book = init_meth(data, k, rng)
  646. for i in range(iter):
  647. # Compute the nearest neighbor for each obs using the current code book
  648. label = vq(data, code_book)[0]
  649. # Update the code book by computing centroids
  650. new_code_book, has_members = _vq.update_cluster_means(data, label, nc)
  651. if not has_members.all():
  652. miss_meth()
  653. # Set the empty clusters to their previous positions
  654. new_code_book[~has_members] = code_book[~has_members]
  655. code_book = new_code_book
  656. return code_book, label