ENH Gaussian mixture bypassing unnecessary initialization computing #26021
Conversation
Thank you for the PR @jiawei-zhang-a! Please add an entry to the change log at doc/whats_new/v1.3.rst with tag |Efficiency|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.
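For illustration, such an entry might look like the sketch below. The wording is a guess, and the display name is inferred from the GitHub handle; only the tag, PR number, and handle come from this thread:

- |Efficiency| :class:`mixture.GaussianMixture` now skips the explicit
  initialization step when `weights_init`, `means_init` and `precisions_init`
  are all provided by the user.
  :pr:`26021` by :user:`Jiawei Zhang <jiawei-zhang-a>`.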
Thank you, Mr. Fan @thomasjpfan. I have removed the new state and added a new changelog entry :)
Thank you for the PR @jiawei-zhang-a! We still need a test to make sure that the parameters are not estimated during initialization. I think a simple way is monkeypatching:
# Intended for sklearn/mixture/tests/test_gaussian_mixture.py, which already
# provides np, GaussianMixture, _estimate_gaussian_parameters, and the
# RandomData helper; Mock comes from unittest.mock (see below).
def test_gaussian_mixture_all_init_does_not_estimate_gaussian_parameters(monkeypatch):
    """When all init are provided, the Gaussian parameters are not estimated.

    Non-regression test for gh26015.
    """
    mock = Mock(side_effect=_estimate_gaussian_parameters)
    monkeypatch.setattr(
        sklearn.mixture._gaussian_mixture, "_estimate_gaussian_parameters", mock
    )
    rng = np.random.RandomState(0)
    rand_data = RandomData(rng)
    gm = GaussianMixture(
        n_components=rand_data.n_components,
        weights_init=rand_data.weights,
        means_init=rand_data.means,
        precisions_init=rand_data.precisions["full"],
        random_state=rng,
    )
    gm.fit(rand_data.X["full"])
    # The initial Gaussian parameters are not estimated;
    # _estimate_gaussian_parameters is only called once per m_step.
    assert mock.call_count == gm.n_iter_
Mock is from Python's unittest.mock module. On main, the test would fail because mock.call_count would be gm.n_iter_ + 1, due to the extra call during initialization.
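For context, here is a toy usage of the pattern the test exercises (the values below are made up, not from the PR): when weights_init, means_init, and precisions_init are all supplied, fit() no longer needs the k-means/random initialization step before EM.

import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.RandomState(0)
X = np.vstack([rng.randn(100, 2) - 2, rng.randn(100, 2) + 2])

gm = GaussianMixture(
    n_components=2,
    weights_init=[0.5, 0.5],
    means_init=[[-2.0, -2.0], [2.0, 2.0]],
    # shape (n_components, n_features, n_features) for the default
    # covariance_type="full"
    precisions_init=np.stack([np.eye(2), np.eye(2)]),
    random_state=0,
)
gm.fit(X)  # with this PR, no initialization estimate runs before EM
print(gm.n_iter_)  # number of EM iterations actually performed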
doc/whats_new/v1.3.rst (outdated)

@@ -43,6 +43,7 @@ random sampling procedures.
     :user:`Jérémie du Boisberranger <jeremiedbb>`,
     :user:`Guillaume Lemaitre <glemaitre>`.
+
For git blame purposes, can you revert this?
Sure!
Dear Mr. Fan, thank you so much for all the advice! I will follow your suggestions!
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
A few minor comments, otherwise LGTM
@@ -34,7 +34,8 @@
 from sklearn.utils._testing import assert_array_almost_equal
 from sklearn.utils._testing import assert_array_equal
 from sklearn.utils._testing import ignore_warnings
+from unittest.mock import Mock
Nit: Can you move this import to line 9, below import warnings? This way the first-party Python modules are at the top of the file.
Sure! I will do that immediately
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
…cikit-learn into GaussianMixture
Good job! Waiting for this to merge
Thanks for the PR @jiawei-zhang-a. Could you kindly resolve the conflicts by merging main and have a look at these few comments?
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
@OmarManzoor Thank you so much for your review! I have committed your suggestions and fixed the conflicts with the main branch.
Thanks for the updates. I added a few more comments; otherwise this looks good now!
Thank you so much!
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Sure! I will check that now
Reference Issues/PRs
Fixes #26015
What does this implement/fix? Explain your changes.
I added a private variable _init_weights_means_precisions_skipped in _base.py. If a user passes initial values for the weights, means, and precisions, there is no need to run the initialization (via k-means or random) to estimate the Gaussian parameters. These two steps are now skipped when _init_weights_means_precisions_skipped is True. A minimal sketch of the idea follows.
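A hypothetical sketch of the bypass, with simplified names (only weights_init/means_init/precisions_init and the flag's meaning come from this PR; the function and its helpers are not the actual scikit-learn internals):

import numpy as np

def initialize_parameters_sketch(gm, X, rng):
    # Stand-in for the private initialization step; `gm` is assumed to carry
    # weights_init / means_init / precisions_init like GaussianMixture.
    skip = (
        gm.weights_init is not None
        and gm.means_init is not None
        and gm.precisions_init is not None
    )  # this condition is what _init_weights_means_precisions_skipped records
    if skip:
        # Bypass: no k-means/random responsibilities and no call to
        # _estimate_gaussian_parameters; just use the provided values.
        return gm.weights_init, gm.means_init, gm.precisions_init
    # Fallback sketch: random responsibilities, then moment estimates.
    resp = rng.uniform(size=(X.shape[0], gm.n_components))
    resp /= resp.sum(axis=1, keepdims=True)
    weights = resp.sum(axis=0) / X.shape[0]
    means = resp.T @ X / resp.sum(axis=0)[:, None]
    return weights, means, None  # precisions left out of this sketch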
Any other comments?