|
8 | 8 | # Author: Ron Weiss <ronweiss@gmail.com>
|
9 | 9 | # Fabian Pedregosa <fabian.pedregosa@inria.fr>
|
10 | 10 | # Bertrand Thirion <bertrand.thirion@inria.fr>
|
11 |
| -# Wei Xue <xuewei4d@gmail.com> |
12 | 11 |
|
13 | 12 | import numpy as np
|
| 13 | +import warnings |
14 | 14 | from scipy import linalg
|
15 | 15 | from time import time
|
16 | 16 | from abc import ABCMeta, abstractmethod
|
17 | 17 |
|
18 |
| -from ..base import BaseEstimator, DensityMixin |
19 |
| -from ..utils import check_random_state, check_array |
| 18 | +from .base import _MixtureBase |
| 19 | +from ..utils import check_random_state, check_array, deprecated |
20 | 20 | from ..utils.extmath import logsumexp
|
21 | 21 | from ..utils.validation import check_is_fitted
|
22 | 22 | from ..externals import six
|
@@ -114,129 +114,7 @@ def sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
|
114 | 114 | return (rand.T + mean).T
|
115 | 115 |
|
116 | 116 |
|
117 |
def _check_X(X, n_components=None, n_features=None):
    """Check the input data X.

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Data to validate; it is passed through ``check_array`` with the
        accepted dtypes ``np.float64`` and ``np.float32``.

    n_components : int, optional
        When given, X must contain at least this many rows.

    n_features : int, optional
        When given, X must contain exactly this many columns.

    Returns
    -------
    X : array, shape (n_samples, n_features)
        The array returned by ``check_array``.

    Raises
    ------
    ValueError
        If X has fewer rows than ``n_components`` or a number of columns
        different from ``n_features``.
    """
    X = check_array(X, dtype=[np.float64, np.float32])
    n_samples, n_columns = X.shape[0], X.shape[1]

    if n_components is not None and n_samples < n_components:
        # The trailing space matters: the two literals are concatenated, and
        # without it the message read "...n_componentsbut got...".
        raise ValueError('Expected n_samples >= n_components '
                         'but got n_components = %d, n_samples = %d'
                         % (n_components, n_samples))

    if n_features is not None and n_columns != n_features:
        raise ValueError("Expected the input data X have %d features, "
                         "but got %d features"
                         % (n_features, n_columns))
    return X
140 |
| - |
141 |
| - |
142 |
class _MixtureBase(six.with_metaclass(ABCMeta, DensityMixin, BaseEstimator)):
    """Abstract foundation shared by mixture-model estimators.

    The class stores the constructor hyper-parameters, validates a subset of
    them, declares the E-step / M-step hooks that concrete models have to
    implement, and offers the public ``fit`` entry point together with a
    count of the free parameters of the model.
    """

    def __init__(self, n_components, covariance_type, tol, reg_covar,
                 max_iter, n_init, init_params, random_state, verbose):
        # Model structure.
        self.n_components = n_components
        self.covariance_type = covariance_type
        self.reg_covar = reg_covar

        # Optimisation / initialisation controls.
        self.tol = tol
        self.max_iter = max_iter
        self.n_init = n_init
        self.init_params = init_params
        self.random_state = random_state
        self.verbose = verbose

        # Placeholders: ``n_features_`` is overwritten by ``fit``;
        # ``converged_`` starts out as False.
        self.n_features_ = None
        self.converged_ = False

    def _check_parameters(self):
        """Validate ``n_init`` and ``covariance_type``.

        Raises
        ------
        ValueError
            If ``n_init`` is smaller than 1, or if ``covariance_type`` is
            not one of 'spherical', 'tied', 'diag' or 'full'.
        """
        if self.n_init < 1:
            n_init_message = ("Invalid value for 'n_init': %d "
                              "Estimation requires at least one run")
            raise ValueError(n_init_message % self.n_init)

        allowed_types = ['spherical', 'tied', 'diag', 'full']
        if self.covariance_type not in allowed_types:
            covariance_message = ("Invalid value for 'covariance_type': %s "
                                  "'covariance_type' should be in "
                                  "['spherical', 'tied', 'diag', 'full']")
            raise ValueError(covariance_message % self.covariance_type)

    @abstractmethod
    def _m_step(self, X, resp):
        """Maximisation step; must be implemented by subclasses.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)

        resp : presumably the responsibilities produced by ``_e_step``
            (shape (n_samples, n_components)) -- confirm with implementers.
        """

    @abstractmethod
    def _e_step(self, X):
        """Expectation step; must be implemented by subclasses.

        Parameters
        ----------
        X : array-like, shape = (n_samples, n_features)

        Returns
        -------
        log_likelihood : float

        resp : array-like, shape = (n_samples, n_components)
        """

    def fit(self, X, y=None):
        """Estimate model parameters with the EM algorithm.

        An initialization step is performed before entering the
        expectation-maximization (EM) algorithm. If you want to avoid
        this step, set the keyword argument init_params to the empty
        string '' when creating the GMM object. Likewise, if you would
        like just to do an initialization, set max_iter=0.

        The input is validated with ``_check_X`` (which also requires at
        least ``n_components`` samples), the hyper-parameters are checked,
        the number of features is recorded in ``n_features_`` and the
        estimation itself is delegated to ``self._fit(X, y)``, which this
        class does not define.

        Parameters
        ----------
        X : array_like, shape (n, n_features)
            List of n_features-dimensional data points. Each row
            corresponds to a single data point.

        y : ignored by this method; forwarded unchanged to ``_fit``.

        Returns
        -------
        self
        """
        # Data validation comes first, then parameter validation, so the
        # order in which errors surface is unchanged.
        validated = _check_X(X, self.n_components)
        self._check_parameters()
        self.n_features_ = validated.shape[1]
        self._fit(validated, y)
        return self

    def _n_parameters(self):
        """Return the number of free parameters in the model.

        The count is the covariance parameters (depending on
        ``covariance_type``), plus one mean entry per feature and component,
        plus ``n_components - 1`` (presumably the mixing weights). Requires
        ``means_`` to be set.
        """
        n_dim = self.means_.shape[1]
        n_comp = self.n_components
        cov_type = self.covariance_type

        # No ``else`` branch on purpose: an unsupported covariance type
        # leaves ``n_cov`` unbound, exactly like the original code.
        if cov_type == 'full':
            n_cov = n_comp * n_dim * (n_dim + 1) / 2.
        elif cov_type == 'diag':
            n_cov = n_comp * n_dim
        elif cov_type == 'tied':
            n_cov = n_dim * (n_dim + 1) / 2.
        elif cov_type == 'spherical':
            n_cov = n_comp

        n_mean = n_dim * n_comp
        return int(n_cov + n_mean + n_comp - 1)
237 |
| - |
238 |
| - |
239 |
| -class GMM(_MixtureBase): |
| 117 | +class _GMMBase(_MixtureBase): |
240 | 118 | """Gaussian Mixture Model.
|
241 | 119 |
|
242 | 120 | Representation of a Gaussian mixture model probability distribution.
|
@@ -362,7 +240,7 @@ def __init__(self, n_components=1, covariance_type='diag',
|
362 | 240 | random_state=None, tol=1e-3, min_covar=1e-3,
|
363 | 241 | n_iter=100, n_init=1, params='wmc', init_params='wmc',
|
364 | 242 | verbose=0):
|
365 |
| - super(GMM, self).__init__( |
| 243 | + super(_GMMBase, self).__init__( |
366 | 244 | n_components=n_components, covariance_type=covariance_type,
|
367 | 245 | random_state=random_state, tol=tol, reg_covar=min_covar,
|
368 | 246 | max_iter=n_iter, n_init=n_init, init_params=init_params,
|
@@ -583,7 +461,15 @@ def _fit(self, X, y=None, do_prediction=False):
|
583 | 461 | Posterior probabilities of each mixture component for each
|
584 | 462 | observation.
|
585 | 463 | """
|
| 464 | + |
586 | 465 | # initialization step
|
| 466 | + X = check_array(X, dtype=np.float64, ensure_min_samples=2, |
| 467 | + estimator=self) |
| 468 | + if X.shape[0] < self.n_components: |
| 469 | + raise ValueError( |
| 470 | + 'GMM estimation with %s components, but got only %s samples' % |
| 471 | + (self.n_components, X.shape[0])) |
| 472 | + |
587 | 473 | max_log_prob = -np.infty
|
588 | 474 |
|
589 | 475 | if self.verbose > 0:
|
@@ -732,6 +618,24 @@ def aic(self, X):
|
732 | 618 | """
|
733 | 619 | return - 2 * self.score(X).sum() + 2 * self._n_parameters()
|
734 | 620 |
|
| 621 | + |
@deprecated("The class GMM is deprecated and "
            "will be removed in 0.20. Use class GaussianMixture instead")
class GMM(_GMMBase):
    """Deprecated public name of the legacy Gaussian mixture model.

    The class adds nothing to ``_GMMBase``: the constructor forwards every
    argument unchanged to the parent. It only exists so that code using the
    old ``GMM`` name keeps working while being told to move to
    ``GaussianMixture``.

    Parameters
    ----------
    n_components, covariance_type, random_state, tol, min_covar, n_iter, \
n_init, params, init_params, verbose
        Passed through verbatim to ``_GMMBase.__init__``.
    """

    def __init__(self, n_components=1, covariance_type='diag',
                 random_state=None, tol=1e-3, min_covar=1e-3,
                 n_iter=100, n_init=1, params='wmc', init_params='wmc',
                 verbose=0):
        # The deprecation warning is issued by the ``@deprecated`` class
        # decorator, which wraps ``__init__``. An additional explicit
        # ``warnings.warn`` here would warn a second time on every
        # construction (and, lacking a category, as a UserWarning).
        super(GMM, self).__init__(
            n_components=n_components, covariance_type=covariance_type,
            random_state=random_state, tol=tol, min_covar=min_covar,
            n_iter=n_iter, n_init=n_init, params=params,
            init_params=init_params, verbose=verbose)
| 638 | + |
735 | 639 | #########################################################################
|
736 | 640 | # some helper routines
|
737 | 641 | #########################################################################
|
|
0 commit comments