Skip to content

Commit a1cea27

Browse files
committed
Move _MixtureBase into base.py.
Add _GMMBase and deprecate the old GMM class. [doc skip]
1 parent 7c7a1f2 commit a1cea27

File tree

4 files changed

+165
-130
lines changed

4 files changed

+165
-130
lines changed

sklearn/mixture/base.py

+131
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
"""Base class for mixture models."""
2+
3+
# Author: Wei Xue <xuewei4d@gmail.com>
4+
5+
import numpy as np
6+
from abc import ABCMeta, abstractmethod
7+
from ..base import BaseEstimator, DensityMixin
8+
from ..utils import check_array
9+
from ..externals import six
10+
11+
12+
def _check_X(X, n_components=None, n_features=None):
13+
"""Check the input data X.
14+
15+
Parameters
16+
----------
17+
X : array-like, (n_samples, n_features)
18+
19+
n_components : int
20+
21+
Returns
22+
-------
23+
X : array, (n_samples, n_features)
24+
"""
25+
X = check_array(X, dtype=[np.float64, np.float32])
26+
if n_components is not None and X.shape[0] < n_components:
27+
raise ValueError('Expected n_samples >= n_components'
28+
'but got n_components = %d, n_samples = %d'
29+
% (n_components, X.shape[0]))
30+
if n_features is not None and X.shape[1] != n_features:
31+
raise ValueError("Expected the input data X have %d features, "
32+
"but got %d features"
33+
% (n_features, X.shape[1]))
34+
return X
35+
36+
37+
class _MixtureBase(six.with_metaclass(ABCMeta, DensityMixin, BaseEstimator)):
38+
"""Base class for mixture models.
39+
40+
This abstract class specifies the interface by abstract methods and
41+
provides basic common methods for mixture models.
42+
"""
43+
def __init__(self, n_components, covariance_type, tol, reg_covar,
44+
max_iter, n_init, init_params, random_state, verbose):
45+
self.n_components = n_components
46+
self.covariance_type = covariance_type
47+
self.tol = tol
48+
self.reg_covar = reg_covar
49+
self.random_state = random_state
50+
self.max_iter = max_iter
51+
self.n_init = n_init
52+
self.init_params = init_params
53+
self.verbose = verbose
54+
55+
self.n_features_ = None
56+
self.converged_ = False
57+
58+
def _check_parameters(self):
59+
if self.n_init < 1:
60+
raise ValueError("Invalid value for 'n_init': %d "
61+
"Estimation requires at least one run"
62+
% self.n_init)
63+
64+
if self.covariance_type not in ['spherical', 'tied', 'diag', 'full']:
65+
raise ValueError("Invalid value for 'covariance_type': %s "
66+
"'covariance_type' should be in "
67+
"['spherical', 'tied', 'diag', 'full']"
68+
% self.covariance_type)
69+
70+
@abstractmethod
71+
def _m_step(self, X, resp):
72+
"""M step
73+
74+
Parameters
75+
----------
76+
X : array-like, shape = (n_samples, n_features)
77+
"""
78+
79+
@abstractmethod
80+
def _e_step(self, X):
81+
"""E Step
82+
83+
Parameters
84+
----------
85+
X : array-like, shape = (n_samples, n_features)
86+
87+
Returns
88+
-------
89+
log_likelihood : float
90+
91+
resp : array-like, shape = (n_samples, n_components)
92+
"""
93+
94+
def fit(self, X, y=None):
95+
"""Estimate model parameters with the EM algorithm.
96+
97+
A initialization step is performed before entering the
98+
expectation-maximization (EM) algorithm. If you want to avoid
99+
this step, set the keyword argument init_params to the empty
100+
string '' when creating the GMM object. Likewise, if you would
101+
like just to do an initialization, set max_iter=0.
102+
103+
Parameters
104+
----------
105+
X : array_like, shape (n, n_features)
106+
List of n_features-dimensional data points. Each row
107+
corresponds to a single data point.
108+
109+
Returns
110+
-------
111+
self
112+
"""
113+
X = _check_X(X, self.n_components)
114+
self._check_parameters()
115+
self.n_features_ = X.shape[1]
116+
self._fit(X, y)
117+
return self
118+
119+
def _n_parameters(self):
120+
"""Return the number of free parameters in the model."""
121+
ndim = self.means_.shape[1]
122+
if self.covariance_type == 'full':
123+
cov_params = self.n_components * ndim * (ndim + 1) / 2.
124+
elif self.covariance_type == 'diag':
125+
cov_params = self.n_components * ndim
126+
elif self.covariance_type == 'tied':
127+
cov_params = ndim * (ndim + 1) / 2.
128+
elif self.covariance_type == 'spherical':
129+
cov_params = self.n_components
130+
mean_params = ndim * self.n_components
131+
return int(cov_params + mean_params + self.n_components - 1)

sklearn/mixture/dpgmm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ..utils.extmath import logsumexp, pinvh, squared_norm
2121
from ..utils.validation import check_is_fitted
2222
from .. import cluster
23-
from .gmm import GMM
23+
from .gmm import _GMMBase
2424

2525

2626
def digamma(x):
@@ -106,7 +106,7 @@ def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
106106
return bound
107107

108108

109-
class DPGMM(GMM):
109+
class DPGMM(_GMMBase):
110110
"""Variational Inference for the Infinite Gaussian Mixture Model.
111111
112112
DPGMM stands for Dirichlet Process Gaussian Mixture Model, and it

sklearn/mixture/gmm.py

+31-127
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,15 @@
88
# Author: Ron Weiss <ronweiss@gmail.com>
99
# Fabian Pedregosa <fabian.pedregosa@inria.fr>
1010
# Bertrand Thirion <bertrand.thirion@inria.fr>
11-
# Wei Xue <xuewei4d@gmail.com>
1211

1312
import numpy as np
13+
import warnings
1414
from scipy import linalg
1515
from time import time
1616
from abc import ABCMeta, abstractmethod
1717

18-
from ..base import BaseEstimator, DensityMixin
19-
from ..utils import check_random_state, check_array
18+
from .base import _MixtureBase
19+
from ..utils import check_random_state, check_array, deprecated
2020
from ..utils.extmath import logsumexp
2121
from ..utils.validation import check_is_fitted
2222
from ..externals import six
@@ -114,129 +114,7 @@ def sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
114114
return (rand.T + mean).T
115115

116116

117-
def _check_X(X, n_components=None, n_features=None):
118-
"""Check the input data X.
119-
120-
Parameters
121-
----------
122-
X : array-like, (n_samples, n_features)
123-
124-
n_components : int
125-
126-
Returns
127-
-------
128-
X : array, (n_samples, n_features)
129-
"""
130-
X = check_array(X, dtype=[np.float64, np.float32])
131-
if n_components is not None and X.shape[0] < n_components:
132-
raise ValueError('Expected n_samples >= n_components'
133-
'but got n_components = %d, n_samples = %d'
134-
% (n_components, X.shape[0]))
135-
if n_features is not None and X.shape[1] != n_features:
136-
raise ValueError("Expected the input data X have %d features, "
137-
"but got %d features"
138-
% (n_features, X.shape[1]))
139-
return X
140-
141-
142-
class _MixtureBase(six.with_metaclass(ABCMeta, DensityMixin, BaseEstimator)):
143-
"""Base class for mixture models.
144-
145-
This abstract class specifies the interface by abstract methods and
146-
provides basic common methods for mixture models.
147-
"""
148-
def __init__(self, n_components, covariance_type, tol, reg_covar,
149-
max_iter, n_init, init_params, random_state, verbose):
150-
self.n_components = n_components
151-
self.covariance_type = covariance_type
152-
self.tol = tol
153-
self.reg_covar = reg_covar
154-
self.random_state = random_state
155-
self.max_iter = max_iter
156-
self.n_init = n_init
157-
self.init_params = init_params
158-
self.verbose = verbose
159-
160-
self.n_features_ = None
161-
self.converged_ = False
162-
163-
def _check_parameters(self):
164-
if self.n_init < 1:
165-
raise ValueError("Invalid value for 'n_init': %d "
166-
"Estimation requires at least one run"
167-
% self.n_init)
168-
169-
if self.covariance_type not in ['spherical', 'tied', 'diag', 'full']:
170-
raise ValueError("Invalid value for 'covariance_type': %s "
171-
"'covariance_type' should be in "
172-
"['spherical', 'tied', 'diag', 'full']"
173-
% self.covariance_type)
174-
175-
@abstractmethod
176-
def _m_step(self, X, resp):
177-
"""M step
178-
179-
Parameters
180-
----------
181-
X : array-like, shape = (n_samples, n_features)
182-
"""
183-
184-
@abstractmethod
185-
def _e_step(self, X):
186-
"""E Step
187-
188-
Parameters
189-
----------
190-
X : array-like, shape = (n_samples, n_features)
191-
192-
Returns
193-
-------
194-
log_likelihood : float
195-
196-
resp : array-like, shape = (n_samples, n_components)
197-
"""
198-
199-
def fit(self, X, y=None):
200-
"""Estimate model parameters with the EM algorithm.
201-
202-
A initialization step is performed before entering the
203-
expectation-maximization (EM) algorithm. If you want to avoid
204-
this step, set the keyword argument init_params to the empty
205-
string '' when creating the GMM object. Likewise, if you would
206-
like just to do an initialization, set max_iter=0.
207-
208-
Parameters
209-
----------
210-
X : array_like, shape (n, n_features)
211-
List of n_features-dimensional data points. Each row
212-
corresponds to a single data point.
213-
214-
Returns
215-
-------
216-
self
217-
"""
218-
X = _check_X(X, self.n_components)
219-
self._check_parameters()
220-
self.n_features_ = X.shape[1]
221-
self._fit(X, y)
222-
return self
223-
224-
def _n_parameters(self):
225-
"""Return the number of free parameters in the model."""
226-
ndim = self.means_.shape[1]
227-
if self.covariance_type == 'full':
228-
cov_params = self.n_components * ndim * (ndim + 1) / 2.
229-
elif self.covariance_type == 'diag':
230-
cov_params = self.n_components * ndim
231-
elif self.covariance_type == 'tied':
232-
cov_params = ndim * (ndim + 1) / 2.
233-
elif self.covariance_type == 'spherical':
234-
cov_params = self.n_components
235-
mean_params = ndim * self.n_components
236-
return int(cov_params + mean_params + self.n_components - 1)
237-
238-
239-
class GMM(_MixtureBase):
117+
class _GMMBase(_MixtureBase):
240118
"""Gaussian Mixture Model.
241119
242120
Representation of a Gaussian mixture model probability distribution.
@@ -362,7 +240,7 @@ def __init__(self, n_components=1, covariance_type='diag',
362240
random_state=None, tol=1e-3, min_covar=1e-3,
363241
n_iter=100, n_init=1, params='wmc', init_params='wmc',
364242
verbose=0):
365-
super(GMM, self).__init__(
243+
super(_GMMBase, self).__init__(
366244
n_components=n_components, covariance_type=covariance_type,
367245
random_state=random_state, tol=tol, reg_covar=min_covar,
368246
max_iter=n_iter, n_init=n_init, init_params=init_params,
@@ -583,7 +461,15 @@ def _fit(self, X, y=None, do_prediction=False):
583461
Posterior probabilities of each mixture component for each
584462
observation.
585463
"""
464+
586465
# initialization step
466+
X = check_array(X, dtype=np.float64, ensure_min_samples=2,
467+
estimator=self)
468+
if X.shape[0] < self.n_components:
469+
raise ValueError(
470+
'GMM estimation with %s components, but got only %s samples' %
471+
(self.n_components, X.shape[0]))
472+
587473
max_log_prob = -np.infty
588474

589475
if self.verbose > 0:
@@ -732,6 +618,24 @@ def aic(self, X):
732618
"""
733619
return - 2 * self.score(X).sum() + 2 * self._n_parameters()
734620

621+
622+
@deprecated("The class GMM is deprecated and "
623+
"will be removed in 0.20. Use class GaussianMixture instead")
624+
class GMM(_GMMBase):
625+
def __init__(self, n_components=1, covariance_type='diag',
626+
random_state=None, tol=1e-3, min_covar=1e-3,
627+
n_iter=100, n_init=1, params='wmc', init_params='wmc',
628+
verbose=0):
629+
super(GMM, self).__init__(
630+
n_components=n_components, covariance_type=covariance_type,
631+
random_state=random_state, tol=tol, min_covar=min_covar,
632+
n_iter=n_iter, n_init=n_init, params=params,
633+
init_params=init_params, verbose=verbose)
634+
635+
warnings.warn(
636+
"The class GMM is deprecated and "
637+
"will be removed in 0.20. Use class GaussianMixture instead")
638+
735639
#########################################################################
736640
# some helper routines
737641
#########################################################################

sklearn/mixture/tests/test_mixture_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def generate_data(n_samples, n_features, weights, means, covariances,
3737

3838

3939
def test_check_X():
40-
from sklearn.mixture.gmm import _check_X
40+
from sklearn.mixture.base import _check_X
4141
rng = np.random.RandomState(0)
4242
n_samples = 10
4343
n_components = 2

0 commit comments

Comments
 (0)