diff --git a/sklearn/mixture/dpgmm.py b/sklearn/mixture/dpgmm.py
index e0db5e568f013..700b63fdad31e 100644
--- a/sklearn/mixture/dpgmm.py
+++ b/sklearn/mixture/dpgmm.py
@@ -116,10 +116,7 @@ def _bound_state_log_lik(X, initial_bound, precs, means, covariance_type):
     return bound


-@deprecated("The DPGMM class is not working correctly and it's better "
-            "to not use it. DPGMM is deprecated in 0.18 and "
-            "will be removed in 0.20.")
-class DPGMM(_GMMBase):
+class _DPGMMBase(_GMMBase):
     """Variational Inference for the Infinite Gaussian Mixture Model.

     DPGMM stands for Dirichlet Process Gaussian Mixture Model, and it
@@ -211,16 +208,16 @@ class DPGMM(_GMMBase):
         algorithm, better for situations where there might be too little
         data to get a good estimate of the covariance matrix.
     """
-
     def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
                  random_state=None, tol=1e-3, verbose=0, min_covar=None,
                  n_iter=10, params='wmc', init_params='wmc'):
         self.alpha = alpha
-        super(DPGMM, self).__init__(n_components, covariance_type,
-                                    random_state=random_state,
-                                    tol=tol, min_covar=min_covar,
-                                    n_iter=n_iter, params=params,
-                                    init_params=init_params, verbose=verbose)
+        super(_DPGMMBase, self).__init__(n_components, covariance_type,
+                                         random_state=random_state,
+                                         tol=tol, min_covar=min_covar,
+                                         n_iter=n_iter, params=params,
+                                         init_params=init_params,
+                                         verbose=verbose)

     def _get_precisions(self):
         """Return precisions as a full matrix."""
@@ -619,10 +616,24 @@ def _fit(self, X, y=None):
         return z


-@deprecated("The VBGMM class is not working correctly and it's better"
-            " to not use it. VBGMM is deprecated in 0.18 and "
+@deprecated("The DPGMM class is not working correctly and it's better "
+            "to not use it. DPGMM is deprecated in 0.18 and "
+            "will be removed in 0.20.")
+class DPGMM(_DPGMMBase):
+    def __init__(self, n_components=1, covariance_type='diag', alpha=1.0,
+                 random_state=None, tol=1e-3, verbose=0, min_covar=None,
+                 n_iter=10, params='wmc', init_params='wmc'):
+        super(DPGMM, self).__init__(
+            n_components=n_components, covariance_type=covariance_type,
+            alpha=alpha, random_state=random_state, tol=tol, verbose=verbose,
+            min_covar=min_covar, n_iter=n_iter, params=params,
+            init_params=init_params)
+
+
+@deprecated("The VBGMM class is not working correctly and it's better "
+            "to not use it. VBGMM is deprecated in 0.18 and "
             "will be removed in 0.20.")
-class VBGMM(DPGMM):
+class VBGMM(_DPGMMBase):
     """Variational Inference for the Gaussian Mixture Model

     Variational inference for a Gaussian mixture model probability
diff --git a/sklearn/mixture/tests/test_dpgmm.py b/sklearn/mixture/tests/test_dpgmm.py
index 88a67170f8268..c636403db7672 100644
--- a/sklearn/mixture/tests/test_dpgmm.py
+++ b/sklearn/mixture/tests/test_dpgmm.py
@@ -182,12 +182,11 @@ class TestDPGMMWithFullCovars(unittest.TestCase, DPGMMTester):
     setUp = GMMTester._setUp


-@ignore_warnings(category=DeprecationWarning)
 def test_VBGMM_deprecation():
-    assert_warns_message(DeprecationWarning, "The VBGMM class is"
-                         " not working correctly and it's better"
-                         " to not use it. VBGMM is deprecated in 0.18"
-                         " and will be removed in 0.20.", VBGMM)
+    assert_warns_message(
+        DeprecationWarning,
+        "The VBGMM class is not working correctly and it's better to not use "
+        "it. VBGMM is deprecated in 0.18 and will be removed in 0.20.", VBGMM)


 class VBGMMTester(GMMTester):
diff --git a/sklearn/mixture/tests/test_gmm.py b/sklearn/mixture/tests/test_gmm.py
index f6765d18b4dbb..500e11117a5a1 100644
--- a/sklearn/mixture/tests/test_gmm.py
+++ b/sklearn/mixture/tests/test_gmm.py
@@ -309,7 +309,7 @@ def test_train_1d(self, params='wmc'):
         with ignore_warnings(category=DeprecationWarning):
             g.fit(X)
             trainll = g.score(X)
-        if isinstance(g, mixture.DPGMM):
+        if isinstance(g, mixture.dpgmm._DPGMMBase):
             self.assertTrue(np.sum(np.abs(trainll / 100)) < 5)
         else:
             self.assertTrue(np.sum(np.abs(trainll / 100)) < 2)
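
Note (illustration only, not part of the patch): the diff keeps the whole implementation in the private _DPGMMBase class and turns the public DPGMM and VBGMM classes into thin wrappers that only carry the @deprecated decorator, so tests such as test_train_1d above can exercise the estimator without tripping the deprecation machinery. A minimal sketch of the expected behaviour, assuming a scikit-learn tree from the 0.18/0.19 series with this patch applied:

# Sketch: the deprecated public class should warn on instantiation,
# while the private base class it now delegates to should not carry
# the DPGMM deprecation decorator.
import warnings

from sklearn.mixture import DPGMM
from sklearn.mixture.dpgmm import _DPGMMBase

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DPGMM(n_components=2)        # deprecated wrapper -> DeprecationWarning
    _DPGMMBase(n_components=2)   # private base class -> no deprecation warning

assert any(issubclass(w.category, DeprecationWarning) for w in caught)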