-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
ENH Add dtype preservation for SpectralClustering
#22669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Closed
Closed
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
582aa0b
TST Adapt test_spectral.py to test implementations on 32bit datasets
jjerphan 2a0b747
Merge branch 'main' into tst/test_spectral-32bit
jjerphan 818e85c
TST Use global_dtype
jjerphan dcc03f5
Address reviews comments
jjerphan b60a65f
Merge branch 'main' into tst/test_spectral-32bit
jjerphan 1d05d2d
Make sure to preserve dtype in SpectralClustering
jjerphan ac6dce4
Merge branch 'main' into tst/test_spectral-32bit
jjerphan File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,7 @@ | |
|
||
@pytest.mark.parametrize("eigen_solver", ("arpack", "lobpcg")) | ||
@pytest.mark.parametrize("assign_labels", ("kmeans", "discretize", "cluster_qr")) | ||
def test_spectral_clustering(eigen_solver, assign_labels): | ||
def test_spectral_clustering(eigen_solver, assign_labels, global_dtype): | ||
S = np.array( | ||
[ | ||
[1.0, 1.0, 1.0, 0.2, 0.0, 0.0, 0.0], | ||
|
@@ -50,7 +50,8 @@ def test_spectral_clustering(eigen_solver, assign_labels): | |
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], | ||
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], | ||
[0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], | ||
] | ||
], | ||
dtype=global_dtype, | ||
) | ||
|
||
for mat in (S, sparse.csr_matrix(S)): | ||
|
@@ -74,14 +75,14 @@ def test_spectral_clustering(eigen_solver, assign_labels): | |
|
||
|
||
@pytest.mark.parametrize("assign_labels", ("kmeans", "discretize", "cluster_qr")) | ||
def test_spectral_clustering_sparse(assign_labels): | ||
def test_spectral_clustering_sparse(assign_labels, global_dtype): | ||
X, y = make_blobs( | ||
n_samples=20, random_state=0, centers=[[1, 1], [-1, -1]], cluster_std=0.01 | ||
) | ||
|
||
S = rbf_kernel(X, gamma=1) | ||
S = np.maximum(S - 1e-4, 0) | ||
S = sparse.coo_matrix(S) | ||
S = sparse.coo_matrix(S, dtype=global_dtype) | ||
|
||
labels = ( | ||
SpectralClustering( | ||
|
@@ -96,11 +97,12 @@ def test_spectral_clustering_sparse(assign_labels): | |
assert adjusted_rand_score(y, labels) == 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check the dtype of |
||
|
||
|
||
def test_precomputed_nearest_neighbors_filtering(): | ||
def test_precomputed_nearest_neighbors_filtering(global_dtype): | ||
# Test precomputed graph filtering when containing too many neighbors | ||
X, y = make_blobs( | ||
n_samples=200, random_state=0, centers=[[1, 1], [-1, -1]], cluster_std=0.01 | ||
) | ||
X = X.astype(global_dtype, copy=False) | ||
|
||
n_neighbors = 2 | ||
results = [] | ||
|
@@ -122,13 +124,14 @@ def test_precomputed_nearest_neighbors_filtering(): | |
assert_array_equal(results[0], results[1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please check the dtype of |
||
|
||
|
||
def test_affinities(): | ||
def test_affinities(global_dtype): | ||
# Note: in the following, random_state has been selected to have | ||
# a dataset that yields a stable eigen decomposition both when built | ||
# on OSX and Linux | ||
X, y = make_blobs( | ||
n_samples=20, random_state=0, centers=[[1, 1], [-1, -1]], cluster_std=0.01 | ||
) | ||
X = X.astype(global_dtype, copy=False) | ||
# nearest neighbors affinity | ||
sp = SpectralClustering(n_clusters=2, affinity="nearest_neighbors", random_state=0) | ||
with pytest.warns(UserWarning, match="not fully connected"): | ||
|
@@ -140,6 +143,7 @@ def test_affinities(): | |
assert adjusted_rand_score(y, labels) == 1 | ||
|
||
X = check_random_state(10).rand(10, 5) * 10 | ||
X = X.astype(global_dtype, copy=False) | ||
|
||
kernels_available = kernel_metrics() | ||
for kern in kernels_available: | ||
|
@@ -182,13 +186,13 @@ def test_cluster_qr(): | |
assert np.array_equal(labels_float64, labels_float32) | ||
|
||
|
||
def test_cluster_qr_permutation_invariance(): | ||
def test_cluster_qr_permutation_invariance(global_dtype): | ||
# cluster_qr must be invariant to sample permutation. | ||
random_state = np.random.RandomState(seed=8) | ||
n_samples, n_components = 100, 5 | ||
data = random_state.randn(n_samples, n_components) | ||
data = random_state.randn(n_samples, n_components).astype(global_dtype, copy=False) | ||
perm = random_state.permutation(n_samples) | ||
assert np.array_equal( | ||
assert_array_equal( | ||
cluster_qr(data)[perm], | ||
cluster_qr(data[perm]), | ||
) | ||
|
@@ -263,12 +267,13 @@ def test_spectral_clustering_with_arpack_amg_solvers(): | |
spectral_clustering(graph, n_clusters=2, eigen_solver="amg", random_state=0) | ||
|
||
|
||
def test_n_components(): | ||
def test_n_components(global_dtype): | ||
# Test that after adding n_components, result is different and | ||
# n_components = n_clusters by default | ||
X, y = make_blobs( | ||
n_samples=20, random_state=0, centers=[[1, 1], [-1, -1]], cluster_std=0.01 | ||
) | ||
X = X.astype(global_dtype, copy=False) | ||
sp = SpectralClustering(n_clusters=2, random_state=0) | ||
labels = sp.fit(X).labels_ | ||
# set n_components = n_cluster and test if result is the same | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should work on making
pairwise_kernels
work efficiently withfloat32
data first.Since
affinity="rbf"
is the default, flaggingSpectralClustering
asdtype
preserving without that prerequisite would be misleading to our users: there would be very few performance or peak memory usage gains in passingfloat32
data to this estimator when using the default hyper-params.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
BTW we also need a test that checks the dtype of
affinity_matrix_
when theaffinity
param is the str name of a kernel function and another such assertion in a test that covers the case when affinity isnearest_neighbors
.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We also need to have
ArgKmin
andRadiusNeighbors
preserve dtypes because it is also used whenaffinity="precomputed_nearest_neighbors"
.Let's turn this PR as a draft for now.