-
-
Notifications
You must be signed in to change notification settings - Fork 26k
TST use global_dtype in sklearn/cluster/tests/test_affinity_propagation.py #22667
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
542f2c7
76df15d
d1fed63
4272be8
168eba5
88780f9
3ee4572
fe5b107
6684b79
30b8a3e
337fa79
7978f2a
0ea2b5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,10 +29,14 @@ | |
random_state=0, | ||
) | ||
|
||
# TODO: AffinityPropagation must preserve dtype for its fitted attributes | ||
# and test must be created accordingly to this new behavior. | ||
# For more details, see: https://github.com/scikit-learn/scikit-learn/issues/11000 | ||
|
||
def test_affinity_propagation(global_random_seed): | ||
|
||
def test_affinity_propagation(global_random_seed, global_dtype): | ||
"""Test consistency of the affinity propagations.""" | ||
S = -euclidean_distances(X, squared=True) | ||
S = -euclidean_distances(X.astype(global_dtype, copy=False), squared=True) | ||
preference = np.median(S) * 10 | ||
cluster_centers_indices, labels = affinity_propagation( | ||
S, preference=preference, random_state=global_random_seed | ||
|
@@ -108,11 +112,12 @@ def test_affinity_propagation_precomputed_with_sparse_input(): | |
AffinityPropagation(affinity="precomputed").fit(csr_matrix((3, 3))) | ||
|
||
|
||
def test_affinity_propagation_predict(global_random_seed): | ||
def test_affinity_propagation_predict(global_random_seed, global_dtype): | ||
# Test AffinityPropagation.predict | ||
af = AffinityPropagation(affinity="euclidean", random_state=global_random_seed) | ||
labels = af.fit_predict(X) | ||
labels2 = af.predict(X) | ||
X_ = X.astype(global_dtype, copy=False) | ||
labels = af.fit_predict(X_) | ||
labels2 = af.predict(X_) | ||
assert_array_equal(labels, labels2) | ||
|
||
|
||
|
@@ -131,23 +136,23 @@ def test_affinity_propagation_predict_error(): | |
af.predict(X) | ||
|
||
|
||
def test_affinity_propagation_fit_non_convergence(): | ||
def test_affinity_propagation_fit_non_convergence(global_dtype): | ||
jjerphan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# In case of non-convergence of affinity_propagation(), the cluster | ||
# centers should be an empty array and training samples should be labelled | ||
# as noise (-1) | ||
X = np.array([[0, 0], [1, 1], [-2, -2]]) | ||
X = np.array([[0, 0], [1, 1], [-2, -2]], dtype=global_dtype) | ||
|
||
# Force non-convergence by allowing only a single iteration | ||
af = AffinityPropagation(preference=-10, max_iter=1, random_state=82) | ||
|
||
with pytest.warns(ConvergenceWarning): | ||
af.fit(X) | ||
assert_array_equal(np.empty((0, 2)), af.cluster_centers_) | ||
assert_allclose(np.empty((0, 2)), af.cluster_centers_) | ||
assert_array_equal(np.array([-1, -1, -1]), af.labels_) | ||
|
||
|
||
def test_affinity_propagation_equal_mutual_similarities(): | ||
X = np.array([[-1, 1], [1, -1]]) | ||
def test_affinity_propagation_equal_mutual_similarities(global_dtype): | ||
X = np.array([[-1, 1], [1, -1]], dtype=global_dtype) | ||
S = -euclidean_distances(X, squared=True) | ||
|
||
# setting preference > similarity | ||
|
@@ -178,10 +183,10 @@ def test_affinity_propagation_equal_mutual_similarities(): | |
assert_array_equal([0, 0], labels) | ||
|
||
|
||
def test_affinity_propagation_predict_non_convergence(): | ||
def test_affinity_propagation_predict_non_convergence(global_dtype): | ||
jjerphan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# In case of non-convergence of affinity_propagation(), the cluster | ||
# centers should be an empty array | ||
X = np.array([[0, 0], [1, 1], [-2, -2]]) | ||
X = np.array([[0, 0], [1, 1], [-2, -2]], dtype=global_dtype) | ||
|
||
# Force non-convergence by allowing only a single iteration | ||
with pytest.warns(ConvergenceWarning): | ||
|
@@ -195,8 +200,10 @@ def test_affinity_propagation_predict_non_convergence(): | |
assert_array_equal(np.array([-1, -1, -1]), y) | ||
|
||
|
||
def test_affinity_propagation_non_convergence_regressiontest(): | ||
X = np.array([[1, 0, 0, 0, 0, 0], [0, 1, 1, 1, 0, 0], [0, 0, 1, 0, 0, 1]]) | ||
def test_affinity_propagation_non_convergence_regressiontest(global_dtype): | ||
jjerphan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
X = np.array( | ||
[[1, 0, 0, 0, 0, 0], [0, 1, 1, 1, 0, 0], [0, 0, 1, 0, 0, 1]], dtype=global_dtype | ||
) | ||
af = AffinityPropagation(affinity="euclidean", max_iter=2, random_state=34) | ||
msg = ( | ||
"Affinity propagation did not converge, this model may return degenerate" | ||
|
@@ -208,17 +215,17 @@ def test_affinity_propagation_non_convergence_regressiontest(): | |
assert_array_equal(np.array([0, 0, 0]), af.labels_) | ||
|
||
|
||
def test_equal_similarities_and_preferences(): | ||
def test_equal_similarities_and_preferences(global_dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The previous test ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Which difference of dtypes are you referring to? |
||
# Unequal distances | ||
X = np.array([[0, 0], [1, 1], [-2, -2]]) | ||
X = np.array([[0, 0], [1, 1], [-2, -2]], dtype=global_dtype) | ||
S = -euclidean_distances(X, squared=True) | ||
|
||
jjerphan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert not _equal_similarities_and_preferences(S, np.array(0)) | ||
jjerphan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
assert not _equal_similarities_and_preferences(S, np.array([0, 0])) | ||
assert not _equal_similarities_and_preferences(S, np.array([0, 1])) | ||
|
||
# Equal distances | ||
X = np.array([[0, 0], [1, 1]]) | ||
X = np.array([[0, 0], [1, 1]], dtype=global_dtype) | ||
S = -euclidean_distances(X, squared=True) | ||
|
||
# Different preferences | ||
|
@@ -251,10 +258,14 @@ def test_affinity_propagation_random_state(): | |
|
||
|
||
@pytest.mark.parametrize("centers", [csr_matrix(np.zeros((1, 10))), np.zeros((1, 10))]) | ||
def test_affinity_propagation_convergence_warning_dense_sparse(centers): | ||
"""Non-regression, see #13334""" | ||
def test_affinity_propagation_convergence_warning_dense_sparse(centers, global_dtype): | ||
jeremiedbb marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
Check that having sparse or dense `centers` format should not | ||
influence the convergence. | ||
Non-regression test for gh-13334. | ||
""" | ||
rng = np.random.RandomState(42) | ||
X = rng.rand(40, 10) | ||
X = rng.rand(40, 10).astype(global_dtype, copy=False) | ||
y = (4 * rng.rand(40)).astype(int) | ||
ap = AffinityPropagation(random_state=46) | ||
ap.fit(X, y) | ||
jjerphan marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -265,11 +276,11 @@ def test_affinity_propagation_convergence_warning_dense_sparse(centers): | |
|
||
|
||
# FIXME; this test is broken with different random states, needs to be revisited | ||
def test_affinity_propagation_float32(): | ||
def test_correct_clusters(global_dtype): | ||
# Test to fix incorrect clusters due to dtype change | ||
# (non-regression test for issue #10832) | ||
X = np.array( | ||
[[1, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 1]], dtype="float32" | ||
[[1, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 1]], dtype=global_dtype | ||
) | ||
afp = AffinityPropagation(preference=1, affinity="precomputed", random_state=0).fit( | ||
X | ||
|
Uh oh!
There was an error while loading. Please reload this page.