Skip to content

[MRG+1] Adding a fit_predict method for the GMM #4593

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2015

Conversation

clorenz7
Copy link

With low iterations, the prediction might not be 100% accurate due to
the final maximization step in the EM algorithm.

See issue:
#4579

self.fit(X, y)

if self.responsibilities_ is None:
raise RuntimeError("Fitting failed, cannot predict")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to use the _check_fitted_model: do a git grep _check_fitted_model to see examples in the code base.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

actually this comment is no longer relevant in light of the other comments.


responsibilities = self._fit(X, y)
if responsibilities is None:
prediction = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How could that ever happen?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would return directly:

return self._fit(X, y).argmax(axis=1)

and make sure that _fit can never return None (make it raise a ValueError or similar with a meaningful error message otherwise).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It could happen in the case when n_iter == 0. You're right that never returning None is better, so I added a check to run score_samples to get the correct output value in that case (that apparently happens when running an HMM). My current idea is just to output zeros because the idea of n_iter=0 seems to be to quickly initialize a model.

@ogrisel
Copy link
Member

ogrisel commented Apr 14, 2015

I am not sure I understand the travis failures, it probably requires to launch a debugger.

@clorenz7
Copy link
Author

@ogrisel Besides addressing your comments, I changed the GMM subclass fit method to _fit, and added some additional test cases.

@ogrisel
Copy link
Member

ogrisel commented Apr 15, 2015

For the travis failure, a solution would be to not implement fit_predict for DPGMM and VBGMM, it is possible to introduce a new _BaseGMM abstract base class with most of the current methods of GMM in it and then make GMM, ``DPGMMandVBGMM`. Finally only implement `fit_predict` in the `GMM` class.

git grep ABCMeta to see how we create abstract base classes in sklearn that support both Python 2 and Python 3 in the same code base.

@ogrisel
Copy link
Member

ogrisel commented Apr 15, 2015

Ah alright, ignore my last comment, I had an internet connection pbm and could not post it when I first wrote it. Now I see that fixed the problems with the subclasses.

component_1 = lrng.randn(n_samples, n_dim) + mu
X = np.vstack((component_0, component_1))

for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks for having updated that test.

@ogrisel
Copy link
Member

ogrisel commented Apr 15, 2015

@eyaler does that PR meet your requirements from #4579?

LGTM, +1 for merge on my side.

@clorenz7 could please just add a new entry in the section on the new features for 0.17.dev0 in the doc/whats_new.rst file?

@ogrisel ogrisel changed the title Adding a fit_predict method for the GMM [MRG+1] Adding a fit_predict method for the GMM Apr 15, 2015
@clorenz7
Copy link
Author

@ogrisel Added what's new. Thanks for all your help!

@@ -480,7 +480,7 @@ def _set_weights(self):
+ self.gamma_[i, 2])
self.weights_ /= np.sum(self.weights_)

def fit(self, X, y=None):
def _fit(self, X, y=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to document its return value.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thanks.

@clorenz7 clorenz7 force-pushed the gmm_fit_predict branch 2 times, most recently from de68a27 to 96bede3 Compare April 26, 2015 23:22
@clorenz7
Copy link
Author

@amueller @ogrisel Sorry for the delay in my response. I hope this looks better now.


Returns
-------
C : array, shape = (n_samples,)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

component_membership? Or a docstring?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, I added that commentary, thanks.

@amueller
Copy link
Member

LGTM apart from my minor comments.

With low iterations, the prediction might not be 100% accurate due to
the final maximization step in the EM algorithm.
@amueller
Copy link
Member

Let's merge when travis is happy.

amueller added a commit that referenced this pull request Apr 30, 2015
[MRG+1] Adding a fit_predict method for the GMM
@amueller amueller merged commit e1fd955 into scikit-learn:master Apr 30, 2015
@amueller
Copy link
Member

thanks.

argument init_params to the empty string '' when creating the
GMM object. Likewise, if you would like just to do an
initialization, set n_iter=0.
def _fit(self, X, y=None, do_prediction=False):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand why there is an additional parameter do_prediction=False. @clorenz7

@xuewei4d xuewei4d mentioned this pull request Jun 17, 2015
16 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants