From 797157ea8d2355ec5c3468f6eaff25acdc82257d Mon Sep 17 00:00:00 2001 From: akikuno Date: Sun, 5 May 2024 16:43:08 +0900 Subject: [PATCH 1/9] Loosened to `dist <= stop_thresh` to converge in on 1D constant data #28926 --- sklearn/cluster/_mean_shift.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index fae11cca7df23..353c69c7a0279 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -121,10 +121,7 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): my_old_mean = my_mean # save the old mean my_mean = np.mean(points_within, axis=0) # If converged or at max_iter, adds the cluster - if ( - np.linalg.norm(my_mean - my_old_mean) < stop_thresh - or completed_iterations == max_iter - ): + if np.linalg.norm(my_mean - my_old_mean) <= stop_thresh or completed_iterations == max_iter: break completed_iterations += 1 return tuple(my_mean), len(points_within), completed_iterations From 2baa163bdac2022ee8dba53f7308cd9988a7db53 Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Sun, 5 May 2024 16:53:25 +0900 Subject: [PATCH 2/9] Linted using black --- sklearn/cluster/_mean_shift.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 353c69c7a0279..a99a607f3cf0d 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -121,7 +121,10 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): my_old_mean = my_mean # save the old mean my_mean = np.mean(points_within, axis=0) # If converged or at max_iter, adds the cluster - if np.linalg.norm(my_mean - my_old_mean) <= stop_thresh or completed_iterations == max_iter: + if ( + np.linalg.norm(my_mean - my_old_mean) <= stop_thresh + or completed_iterations == max_iter + ): break completed_iterations += 1 return tuple(my_mean), len(points_within), completed_iterations From 1495676dbe53ce0a0effe7b0ed364d2a29b7ead2 Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Mon, 6 May 2024 18:53:08 +0900 Subject: [PATCH 3/9] Add changelog of `MeanShift` enhancement #28951 --- doc/whats_new/v1.5.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index ede5d5dcbf1ec..3980f7a80e680 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -43,6 +43,9 @@ Security Changed models -------------- +- |Efficiency| The `clustering.MeanShift` class has now improved computational speed as it properly converges for constant data. + :pr:`28951` by :user:`Akihiro Kuno `. + - |Efficiency| The subsampling in :class:`preprocessing.QuantileTransformer` is now more efficient for dense arrays but the fitted quantiles and the results of `transform` may be slightly different than before (keeping the same statistical From 20caedae6aef18544ef93e40eed48b0dbc706e8b Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Mon, 6 May 2024 18:54:04 +0900 Subject: [PATCH 4/9] Add tests for `MeanShift` to ensure convergence with constant data. --- sklearn/cluster/tests/test_mean_shift.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 265c72d0c4ce1..ed6879a76bf12 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -25,6 +25,19 @@ ) +def test_convergence_of_1d_constant_data(): + # Test convergence using 1D constant data + x = np.concatenate([np.zeros(10), np.ones(10)]) + n_iter = MeanShift().fit(x.reshape(-1,1)).n_iter_ + assert n_iter < 300 + +def test_convergence_of_2d_constant_data(): + # Test convergence using 2D constant data + x = np.concatenate([np.zeros((10, 10)), np.ones((10, 10))]) + n_iter = MeanShift().fit(x).n_iter_ + assert n_iter < 300 + + def test_estimate_bandwidth(): # Test estimate_bandwidth bandwidth = estimate_bandwidth(X, n_samples=200) From 6ed39ae8434c35e5b03766f60f8ba9569dc0e428 Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Tue, 7 May 2024 08:39:51 +0900 Subject: [PATCH 5/9] Update sklearn/cluster/tests/test_mean_shift.py Co-authored-by: Olivier Grisel --- sklearn/cluster/tests/test_mean_shift.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index ed6879a76bf12..65b3dca976442 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -31,6 +31,7 @@ def test_convergence_of_1d_constant_data(): n_iter = MeanShift().fit(x.reshape(-1,1)).n_iter_ assert n_iter < 300 + def test_convergence_of_2d_constant_data(): # Test convergence using 2D constant data x = np.concatenate([np.zeros((10, 10)), np.ones((10, 10))]) From 33892d3841cc6214fa38f03e1b01b243a54a6a30 Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Tue, 7 May 2024 08:40:12 +0900 Subject: [PATCH 6/9] Update sklearn/cluster/tests/test_mean_shift.py Co-authored-by: Olivier Grisel --- sklearn/cluster/tests/test_mean_shift.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 65b3dca976442..7c02372dcbc72 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -27,9 +27,11 @@ def test_convergence_of_1d_constant_data(): # Test convergence using 1D constant data - x = np.concatenate([np.zeros(10), np.ones(10)]) - n_iter = MeanShift().fit(x.reshape(-1,1)).n_iter_ - assert n_iter < 300 + # Non-regression test for: + # https://github.com/scikit-learn/scikit-learn/issues/28926 + model = MeanShift() + n_iter = model.fit(np.ones(10).reshape(-1, 1)).n_iter_ + assert n_iter < model.max_iter def test_convergence_of_2d_constant_data(): From edb19ee37caab2d32144bd85934875b4ad8d06b0 Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Thu, 9 May 2024 10:53:21 +0900 Subject: [PATCH 7/9] Remove the 2d case to test the convergence of constant data --- sklearn/cluster/tests/test_mean_shift.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sklearn/cluster/tests/test_mean_shift.py b/sklearn/cluster/tests/test_mean_shift.py index 7c02372dcbc72..d2d73ba11a3ec 100644 --- a/sklearn/cluster/tests/test_mean_shift.py +++ b/sklearn/cluster/tests/test_mean_shift.py @@ -34,13 +34,6 @@ def test_convergence_of_1d_constant_data(): assert n_iter < model.max_iter -def test_convergence_of_2d_constant_data(): - # Test convergence using 2D constant data - x = np.concatenate([np.zeros((10, 10)), np.ones((10, 10))]) - n_iter = MeanShift().fit(x).n_iter_ - assert n_iter < 300 - - def test_estimate_bandwidth(): # Test estimate_bandwidth bandwidth = estimate_bandwidth(X, n_samples=200) From fb9dd4af3898c3e17cea37564d5197c5d8823a98 Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Thu, 9 May 2024 10:58:38 +0900 Subject: [PATCH 8/9] Apply the suggested change from @jeremiedbb MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Jérémie du Boisberranger --- doc/whats_new/v1.5.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index 3980f7a80e680..ba659e18d1d31 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -43,7 +43,7 @@ Security Changed models -------------- -- |Efficiency| The `clustering.MeanShift` class has now improved computational speed as it properly converges for constant data. +- |Fix| The :class:`cluster.MeanShift` class now properly converges for constant data. :pr:`28951` by :user:`Akihiro Kuno `. - |Efficiency| The subsampling in :class:`preprocessing.QuantileTransformer` is now From 1ecc1f3c73c8f4738cb753dc8d7345c5530880cf Mon Sep 17 00:00:00 2001 From: Akihiro Kuno Date: Thu, 9 May 2024 11:02:18 +0900 Subject: [PATCH 9/9] Move the changelog to the `sklearn.cluster` section --- doc/whats_new/v1.5.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.5.rst b/doc/whats_new/v1.5.rst index ba659e18d1d31..80fa01dd967cc 100644 --- a/doc/whats_new/v1.5.rst +++ b/doc/whats_new/v1.5.rst @@ -43,9 +43,6 @@ Security Changed models -------------- -- |Fix| The :class:`cluster.MeanShift` class now properly converges for constant data. - :pr:`28951` by :user:`Akihiro Kuno `. - - |Efficiency| The subsampling in :class:`preprocessing.QuantileTransformer` is now more efficient for dense arrays but the fitted quantiles and the results of `transform` may be slightly different than before (keeping the same statistical @@ -178,6 +175,9 @@ Changelog :mod:`sklearn.cluster` ...................... +- |Fix| The :class:`cluster.MeanShift` class now properly converges for constant data. + :pr:`28951` by :user:`Akihiro Kuno `. + - |FIX| Create copy of precomputed sparse matrix within the `fit` method of :class:`~cluster.OPTICS` to avoid in-place modification of the sparse matrix. :pr:`28491` by :user:`Thanh Lam Dang `.