-
-
Notifications
You must be signed in to change notification settings - Fork 26k
[MRG+1] Adding a fit_predict method for the GMM #4593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
self.fit(X, y) | ||
|
||
if self.responsibilities_ is None: | ||
raise RuntimeError("Fitting failed, cannot predict") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's better to use the _check_fitted_model
: do a git grep _check_fitted_model
to see examples in the code base.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
actually this comment is no longer relevant in light of the other comments.
9349c48
to
18d4df4
Compare
|
||
responsibilities = self._fit(X, y) | ||
if responsibilities is None: | ||
prediction = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How could that ever happen?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would return directly:
return self._fit(X, y).argmax(axis=1)
and make sure that _fit
can never return None
(make it raise a ValueError
or similar with a meaningful error message otherwise).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It could happen in the case when n_iter == 0. You're right that never returning None is better, so I added a check to run score_samples to get the correct output value in that case (that apparently happens when running an HMM). My current idea is just to output zeros because the idea of n_iter=0 seems to be to quickly initialize a model.
I am not sure I understand the travis failures, it probably requires to launch a debugger. |
18d4df4
to
986defe
Compare
@ogrisel Besides addressing your comments, I changed the GMM subclass |
For the travis failure, a solution would be to not implement
|
Ah alright, ignore my last comment, I had an internet connection pbm and could not post it when I first wrote it. Now I see that fixed the problems with the subclasses. |
component_1 = lrng.randn(n_samples, n_dim) + mu | ||
X = np.vstack((component_0, component_1)) | ||
|
||
for m_constructor in (mixture.GMM, mixture.VBGMM, mixture.DPGMM): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great! Thanks for having updated that test.
986defe
to
234bd07
Compare
@ogrisel Added what's new. Thanks for all your help! |
234bd07
to
ff4d8d2
Compare
@@ -480,7 +480,7 @@ def _set_weights(self): | |||
+ self.gamma_[i, 2]) | |||
self.weights_ /= np.sum(self.weights_) | |||
|
|||
def fit(self, X, y=None): | |||
def _fit(self, X, y=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It needs to document its return value.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thanks.
de68a27
to
96bede3
Compare
|
||
Returns | ||
------- | ||
C : array, shape = (n_samples,) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
component_membership? Or a docstring?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point, I added that commentary, thanks.
LGTM apart from my minor comments. |
With low iterations, the prediction might not be 100% accurate due to the final maximization step in the EM algorithm.
96bede3
to
fc87eb3
Compare
Let's merge when travis is happy. |
[MRG+1] Adding a fit_predict method for the GMM
thanks. |
argument init_params to the empty string '' when creating the | ||
GMM object. Likewise, if you would like just to do an | ||
initialization, set n_iter=0. | ||
def _fit(self, X, y=None, do_prediction=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand why there is an additional parameter do_prediction=False
. @clorenz7
With low iterations, the prediction might not be 100% accurate due to
the final maximization step in the EM algorithm.
See issue:
#4579