Skip to content

[MRG+2] Correct Deprecation of DPGMM and VBGMM. #7124

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions sklearn/mixture/dpgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,7 @@ def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
return bound


@deprecated("The DPGMM class is not working correctly and it's better "
"to not use it. DPGMM is deprecated in 0.18 and "
"will be removed in 0.20.")
class DPGMM(_GMMBase):
class _DPGMMBase(_GMMBase):
"""Variational Inference for the Infinite Gaussian Mixture Model.

DPGMM stands for Dirichlet Process Gaussian Mixture Model, and it
Expand Down Expand Up @@ -211,16 +208,16 @@ class DPGMM(_GMMBase):
algorithm, better for situations where there might be too little
data to get a good estimate of the covariance matrix.
"""

def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
random_state=None, tol=1e-3, verbose=0, min_covar=None,
n_iter=10, params='wmc', init_params='wmc'):
self.alpha = alpha
super(DPGMM, self).__init__(n_components, covariance_type,
random_state=random_state,
tol=tol, min_covar=min_covar,
n_iter=n_iter, params=params,
init_params=init_params, verbose=verbose)
super(_DPGMMBase, self).__init__(n_components, covariance_type,
random_state=random_state,
tol=tol, min_covar=min_covar,
n_iter=n_iter, params=params,
init_params=init_params,
verbose=verbose)

def _get_precisions(self):
"""Return precisions as a full matrix."""
Expand Down Expand Up @@ -619,10 +616,24 @@ def _fit(self, X, y=None):
return z


@deprecated("The VBGMM class is not working correctly and it's better"
" to not use it. VBGMM is deprecated in 0.18 and "
@deprecated("The DPGMM class is not working correctly and it's better "
"to not use it. DPGMM is deprecated in 0.18 and "
"will be removed in 0.20.")
class DPGMM(_DPGMMBase):
def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
random_state=None, tol=1e-3, verbose=0, min_covar=None,
n_iter=10, params='wmc', init_params='wmc'):
super(DPGMM, self).__init__(
n_components=n_components, covariance_type=covariance_type,
alpha=alpha, random_state=random_state, tol=tol, verbose=verbose,
min_covar=min_covar, n_iter=n_iter, params=params,
init_params=init_params)


@deprecated("The VBGMM class is not working correctly and it's better "
"to not use it. VBGMM is deprecated in 0.18 and "
"will be removed in 0.20.")
class VBGMM(DPGMM):
class VBGMM(_DPGMMBase):
"""Variational Inference for the Gaussian Mixture Model

Variational inference for a Gaussian mixture model probability
Expand Down
9 changes: 4 additions & 5 deletions sklearn/mixture/tests/test_dpgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,11 @@ class TestDPGMMWithFullCovars(unittest.TestCase, DPGMMTester):
setUp = GMMTester._setUp


@ignore_warnings(category=DeprecationWarning)
def test_VBGMM_deprecation():
assert_warns_message(DeprecationWarning, "The VBGMM class is"
" not working correctly and it's better"
" to not use it. VBGMM is deprecated in 0.18"
" and will be removed in 0.20.", VBGMM)
assert_warns_message(
DeprecationWarning,
"The VBGMM class is not working correctly and it's better to not use "
"it. VBGMM is deprecated in 0.18 and will be removed in 0.20.", VBGMM)


class VBGMMTester(GMMTester):
Expand Down
2 changes: 1 addition & 1 deletion sklearn/mixture/tests/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ def test_train_1d(self, params='wmc'):
with ignore_warnings(category=DeprecationWarning):
g.fit(X)
trainll = g.score(X)
if isinstance(g, mixture.DPGMM):
if isinstance(g, mixture.dpgmm._DPGMMBase):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why that change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DPGMM and VBGMM are both deprecated and as VBGMM inherit from DPGMM you have two warnings messages when you call VBGMM.
I've corrected that to make sure the deprecation warning is meaningfull and moreover to be sure to pass the test once I will modify the deprecation message to refer to the new BayesianGaussianMixture class.

self.assertTrue(np.sum(np.abs(trainll / 100)) < 5)
else:
self.assertTrue(np.sum(np.abs(trainll / 100)) < 2)
Expand Down