Skip to content

ENH Gaussian mixture bypassing unnecessary initialization computing #26021

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 21 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
bd09eb8
Bypassing initialization when initial values weights, means, and prec…
jiawei-zhang-a Mar 29, 2023
cbe04d6
Bypassing initialization when initial values weights, means, and prec…
jiawei-zhang-a Mar 29, 2023
c978dbf
change the variable name to meet the standard
jiawei-zhang-a Mar 31, 2023
671eb35
Merge branch 'main' into GaussianMixture
jiawei-zhang-a Mar 31, 2023
f86fbba
Merged
jiawei-zhang-a Mar 31, 2023
2c99237
make the parameter private for base and GaussianMixture
jiawei-zhang-a Mar 31, 2023
1aae098
Merge branch 'scikit-learn:main' into GaussianMixture
jiawei-zhang-a Apr 14, 2023
5cb80f6
no var
jiawei-zhang-a Apr 14, 2023
b480670
Update sklearn/mixture/_gaussian_mixture.py
jiawei-zhang-a Apr 28, 2023
b95b2f7
Add test for bypassing
jiawei-zhang-a Apr 28, 2023
83b7a7f
Update sklearn/mixture/tests/test_gaussian_mixture.py
jiawei-zhang-a Apr 28, 2023
403031c
small adjustment
jiawei-zhang-a Apr 28, 2023
efd7176
Merge branch 'GaussianMixture' of https://github.com/jiawei-zhang-a/s…
jiawei-zhang-a Apr 28, 2023
bef22c6
Update doc/whats_new/v1.3.rst
jiawei-zhang-a Aug 9, 2023
a080d40
Update sklearn/mixture/_gaussian_mixture.py
jiawei-zhang-a Aug 9, 2023
f54f541
Fix the conflict
jiawei-zhang-a Aug 9, 2023
67a2640
Update doc/whats_new/v1.3.rst
jiawei-zhang-a Aug 10, 2023
3e3d828
Update sklearn/mixture/tests/test_gaussian_mixture.py
jiawei-zhang-a Aug 10, 2023
981301c
Update sklearn/mixture/tests/test_gaussian_mixture.py
jiawei-zhang-a Aug 10, 2023
9f38ea7
Update sklearn/mixture/tests/test_gaussian_mixture.py
jiawei-zhang-a Aug 10, 2023
9ce05ce
change line for > 80 character
jiawei-zhang-a Aug 10, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,13 @@ Changelog
`n_targets`, which is used to decide the number of outputs when sampling
from the prior distributions. :pr:`23099` by :user:`Zhehao Liu <MaxwellLZH>`.

:mod:`sklearn.mixture`
......................

- |Efficiency| :class:`GaussianMixture` is now more efficient: it bypasses the
unnecessary initialization step when the weights, means, and precisions are all
provided by the user.
:pr:`26021` by :user:`Jiawei Zhang <jiawei-zhang-a>`.

:mod:`sklearn.model_selection`
..............................

Expand Down
25 changes: 20 additions & 5 deletions sklearn/mixture/_gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,19 @@ def _check_parameters(self, X):
n_features,
)

def _initialize_parameters(self, X, random_state):
    """Initialize the model parameters, skipping responsibility computation
    when the user supplied every initial parameter.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features)
        The input data.

    random_state : RandomState
        A random number generator instance used by the base-class
        initialization (ignored when all initial parameters are given).
    """
    # Responsibilities only need to be computed when at least one of the
    # initial parameters is missing; otherwise the base-class
    # initialization (k-means / random / ...) would be wasted work.
    user_provided_all = (
        self.weights_init is not None
        and self.means_init is not None
        and self.precisions_init is not None
    )
    if user_provided_all:
        # Pass resp=None: _initialize will take every parameter from the
        # *_init attributes instead of estimating them.
        self._initialize(X, None)
    else:
        super()._initialize_parameters(X, random_state)

def _initialize(self, X, resp):
"""Initialization of the Gaussian mixture parameters.

Expand All @@ -764,11 +777,13 @@ def _initialize(self, X, resp):
resp : array-like of shape (n_samples, n_components)
"""
n_samples, _ = X.shape

weights, means, covariances = _estimate_gaussian_parameters(
X, resp, self.reg_covar, self.covariance_type
)
weights /= n_samples
weights, means, covariances = None, None, None
if resp is not None:
weights, means, covariances = _estimate_gaussian_parameters(
X, resp, self.reg_covar, self.covariance_type
)
if self.weights_init is None:
weights /= n_samples

self.weights_ = weights if self.weights_init is None else self.weights_init
self.means_ = means if self.means_init is None else self.means_init
Expand Down
33 changes: 33 additions & 0 deletions sklearn/mixture/tests/test_gaussian_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@
import sys
import warnings
from io import StringIO
from unittest.mock import Mock

import numpy as np
import pytest
from scipy import linalg, stats

import sklearn
from sklearn.cluster import KMeans
from sklearn.covariance import EmpiricalCovariance
from sklearn.datasets import make_spd_matrix
Expand Down Expand Up @@ -1387,3 +1389,34 @@ def test_gaussian_mixture_single_component_stable():
X = rng.multivariate_normal(np.zeros(2), np.identity(2), size=3)
gm = GaussianMixture(n_components=1)
gm.fit(X).sample()


def test_gaussian_mixture_all_init_does_not_estimate_gaussian_parameters(
    monkeypatch,
    global_random_seed,
):
    """Fitting with all init parameters supplied must skip the initial
    estimation of the Gaussian parameters.

    Non-regression test for gh-26015.
    """
    rng = np.random.RandomState(global_random_seed)
    rand_data = RandomData(rng)

    # Spy on _estimate_gaussian_parameters while preserving its behavior,
    # so we can count how many times it is invoked during fit.
    spy = Mock(side_effect=_estimate_gaussian_parameters)
    monkeypatch.setattr(
        sklearn.mixture._gaussian_mixture, "_estimate_gaussian_parameters", spy
    )

    gm = GaussianMixture(
        n_components=rand_data.n_components,
        weights_init=rand_data.weights,
        means_init=rand_data.means,
        precisions_init=rand_data.precisions["full"],
        random_state=rng,
    )
    gm.fit(rand_data.X["full"])

    # With every initial parameter provided, estimation happens only once
    # per M-step — never during initialization — so the call count equals
    # the number of EM iterations.
    assert spy.call_count == gm.n_iter_