Skip to content

Commit ff4d8d2

Browse files
author
Cory Lorenz
committed
Add a fit_predict method for the GMM classes
With low iterations, the prediction might not be 100% accurate due to the final maximization step in the EM algorithm.
1 parent 203298e commit ff4d8d2

File tree

4 files changed

+92
-11
lines changed

4 files changed

+92
-11
lines changed

doc/whats_new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ Enhancements
2121

2222
- :class:`dummy.DummyClassifier` now supports a prior fitting strategy.
2323
By `Arnaud Joly`_.
24+
- Added a ``fit_predict`` method for :class:`mixture.GMM` and subclasses.
25+
By `Cory Lorenz`_.
2426

2527
Bug fixes
2628
.........
@@ -3413,3 +3415,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
34133415
.. _Eric Martin: http://ericmart.in
34143416

34153417
.. _Nicolas Goix: https://webperso.telecom-paristech.fr/front/frontoffice.php?SP_ID=241
3418+
3419+
.. _Cory Lorenz: https://github.com/clorenz7

sklearn/mixture/dpgmm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def _set_weights(self):
480480
+ self.gamma_[i, 2])
481481
self.weights_ /= np.sum(self.weights_)
482482

483-
def fit(self, X, y=None):
483+
def _fit(self, X, y=None):
484484
"""Estimate model parameters with the variational
485485
algorithm.
486486
@@ -595,9 +595,14 @@ def fit(self, X, y=None):
595595
# Maximization step
596596
self._do_mstep(X, z, self.params)
597597

598+
if self.n_iter == 0:
599+
# Need to make sure that there is a z value to output
600+
# Output zeros because it was just a quick initialization
601+
z = np.zeros(X.shape[0], self.n_components)
602+
598603
self._set_weights()
599604

600-
return self
605+
return z
601606

602607

603608
class VBGMM(DPGMM):

sklearn/mixture/gmm.py

Lines changed: 54 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -411,22 +411,44 @@ def sample(self, n_samples=1, random_state=None):
411411
num_comp_in_X, random_state=random_state).T
412412
return X
413413

414-
def fit(self, X, y=None):
415-
"""Estimate model parameters with the expectation-maximization
416-
algorithm.
414+
def fit_predict(self, X, y=None):
415+
"""
416+
Fit and then predict labels for data.
417+
Warning: due to the final maximization step in the EM algorithm,
418+
with low iterations the prediction may not be 100% accurate
417419
418-
A initialization step is performed before entering the em
419-
algorithm. If you want to avoid this step, set the keyword
420-
argument init_params to the empty string '' when creating the
421-
GMM object. Likewise, if you would like just to do an
422-
initialization, set n_iter=0.
420+
Parameters
421+
----------
422+
X : array-like, shape = [n_samples, n_features]
423+
424+
Returns
425+
-------
426+
C : array, shape = (n_samples,)
427+
"""
428+
return self._fit(X, y).argmax(axis=1)
429+
430+
def _fit(self, X, y=None, do_prediction=False):
431+
"""Estimate model parameters with the EM algorithm.
432+
433+
A initialization step is performed before entering the
434+
expectation-maximization (EM) algorithm. If you want to avoid
435+
this step, set the keyword argument init_params to the empty
436+
string '' when creating the GMM object. Likewise, if you would
437+
like just to do an initialization, set n_iter=0.
423438
424439
Parameters
425440
----------
426441
X : array_like, shape (n, n_features)
427442
List of n_features-dimensional data points. Each row
428443
corresponds to a single data point.
444+
445+
Returns
446+
-------
447+
responsibilities : array, shape (n_samples, n_components)
448+
Posterior probabilities of each mixture component for each
449+
observation
429450
"""
451+
430452
# initialization step
431453
X = check_array(X, dtype=np.float64)
432454
if X.shape[0] < self.n_components:
@@ -501,10 +523,33 @@ def fit(self, X, y=None):
501523
self.covars_ = best_params['covars']
502524
self.means_ = best_params['means']
503525
self.weights_ = best_params['weights']
526+
else:
527+
# Need to make sure that there are responsibilities to output
528+
# Output zeros because it was just a quick initialization
529+
responsibilities = np.zeros(X.shape[0], self.n_components)
530+
531+
return responsibilities
532+
533+
def fit(self, X, y=None):
534+
"""Estimate model parameters with the EM algorithm.
535+
536+
A initialization step is performed before entering the
537+
expectation-maximization (EM) algorithm. If you want to avoid
538+
this step, set the keyword argument init_params to the empty
539+
string '' when creating the GMM object. Likewise, if you would
540+
like just to do an initialization, set n_iter=0.
541+
542+
Parameters
543+
----------
544+
X : array_like, shape (n, n_features)
545+
List of n_features-dimensional data points. Each row
546+
corresponds to a single data point.
547+
"""
548+
self._fit(X, y)
504549
return self
505550

506551
def _do_mstep(self, X, responsibilities, params, min_covar=0):
507-
""" Perform the Mstep of the EM algorithm and return the class weihgts.
552+
""" Perform the Mstep of the EM algorithm and return the class weights
508553
"""
509554
weights = responsibilities.sum(axis=0)
510555
weighted_X_sum = np.dot(responsibilities.T, X)

sklearn/mixture/tests/test_gmm.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sklearn import mixture
99
from sklearn.datasets.samples_generator import make_spd_matrix
1010
from sklearn.utils.testing import assert_greater
11+
from sklearn.metrics.cluster import adjusted_rand_score
1112

1213
rng = np.random.RandomState(0)
1314

@@ -327,6 +328,32 @@ def test_1d_1component():
327328
assert_array_almost_equal(g.bic(X), g_full_bic)
328329

329330

331+
def assert_fit_predict_correct(model, X):
332+
333+
predictions_1 = model.fit(X).predict(X)
334+
predictions_2 = model.fit_predict(X)
335+
336+
assert adjusted_rand_score(predictions_1, predictions_2) == 1.0
337+
338+
339+
def test_fit_predict():
340+
"""
341+
test that gmm.fit_predict is equivalent to gmm.fit + gmm.predict
342+
"""
343+
lrng = np.random.RandomState(0)
344+
345+
n_samples, n_dim, n_comps = 100, 2, 2
346+
mu = np.array([[8, 8]])
347+
component_0 = lrng.randn(n_samples, n_dim)
348+
component_1 = lrng.randn(n_samples, n_dim) + mu
349+
X = np.vstack((component_0, component_1))
350+
351+
for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM):
352+
model = m_constructor(n_components=n_comps, covariance_type='full',
353+
min_covar=1e-7, n_iter=5, random_state=lrng)
354+
assert_fit_predict_correct(model, X)
355+
356+
330357
def test_aic():
331358
# Test the aic and bic criteria
332359
n_samples, n_dim, n_components = 50, 3, 2

0 commit comments

Comments
 (0)