-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG] Fix missing assert and parametrize some k-means tests #12368
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -181,12 +181,6 @@ def _check_fitted_model(km): | |
% km.n_clusters, km.fit, [[0., 1.]]) | ||
|
||
|
||
def test_k_means_plus_plus_init(): | ||
km = KMeans(init="k-means++", n_clusters=n_clusters, | ||
random_state=42).fit(X) | ||
_check_fitted_model(km) | ||
|
||
|
||
def test_k_means_new_centers(): | ||
# Explore the part of the code where a new center is reassigned | ||
X = np.array([[0, 0, 1, 1], | ||
|
@@ -229,24 +223,6 @@ def test_k_means_precompute_distances_flag(): | |
assert_raises(ValueError, km.fit, X) | ||
|
||
|
||
def test_k_means_plus_plus_init_sparse(): | ||
km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42) | ||
km.fit(X_csr) | ||
_check_fitted_model(km) | ||
|
||
|
||
def test_k_means_random_init(): | ||
km = KMeans(init="random", n_clusters=n_clusters, random_state=42) | ||
km.fit(X) | ||
_check_fitted_model(km) | ||
|
||
|
||
def test_k_means_random_init_sparse(): | ||
km = KMeans(init="random", n_clusters=n_clusters, random_state=42) | ||
km.fit(X_csr) | ||
_check_fitted_model(km) | ||
|
||
|
||
def test_k_means_plus_plus_init_not_precomputed(): | ||
km = KMeans(init="k-means++", n_clusters=n_clusters, random_state=42, | ||
precompute_distances=False).fit(X) | ||
|
@@ -259,10 +235,11 @@ def test_k_means_random_init_not_precomputed(): | |
_check_fitted_model(km) | ||
|
||
|
||
def test_k_means_perfect_init(): | ||
km = KMeans(init=centers.copy(), n_clusters=n_clusters, random_state=42, | ||
n_init=1) | ||
km.fit(X) | ||
@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse']) | ||
@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()]) | ||
def test_k_means_init(data, init): | ||
km = KMeans(init=init, n_clusters=n_clusters, random_state=42, n_init=1) | ||
km.fit(data) | ||
_check_fitted_model(km) | ||
|
||
|
||
|
@@ -315,13 +292,6 @@ def test_k_means_fortran_aligned_data(): | |
assert_array_equal(km.labels_, labels) | ||
|
||
|
||
def test_mb_k_means_plus_plus_init_dense_array(): | ||
mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters, | ||
random_state=42) | ||
mb_k_means.fit(X) | ||
_check_fitted_model(mb_k_means) | ||
|
||
|
||
def test_mb_kmeans_verbose(): | ||
mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters, | ||
random_state=42, verbose=1) | ||
|
@@ -333,49 +303,25 @@ def test_mb_kmeans_verbose(): | |
sys.stdout = old_stdout | ||
|
||
|
||
def test_mb_k_means_plus_plus_init_sparse_matrix(): | ||
mb_k_means = MiniBatchKMeans(init="k-means++", n_clusters=n_clusters, | ||
random_state=42) | ||
mb_k_means.fit(X_csr) | ||
_check_fitted_model(mb_k_means) | ||
|
||
|
||
def test_minibatch_init_with_large_k(): | ||
mb_k_means = MiniBatchKMeans(init='k-means++', init_size=10, n_clusters=20) | ||
# Check that a warning is raised, as the number clusters is larger | ||
# than the init_size | ||
assert_warns(RuntimeWarning, mb_k_means.fit, X) | ||
|
||
|
||
def test_minibatch_k_means_random_init_dense_array(): | ||
# increase n_init to make random init stable enough | ||
mb_k_means = MiniBatchKMeans(init="random", n_clusters=n_clusters, | ||
random_state=42, n_init=10).fit(X) | ||
_check_fitted_model(mb_k_means) | ||
|
||
|
||
def test_minibatch_k_means_random_init_sparse_csr(): | ||
# increase n_init to make random init stable enough | ||
mb_k_means = MiniBatchKMeans(init="random", n_clusters=n_clusters, | ||
random_state=42, n_init=10).fit(X_csr) | ||
_check_fitted_model(mb_k_means) | ||
|
||
|
||
def test_minibatch_k_means_perfect_init_dense_array(): | ||
mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters, | ||
random_state=42, n_init=1).fit(X) | ||
_check_fitted_model(mb_k_means) | ||
|
||
|
||
def test_minibatch_k_means_init_multiple_runs_with_explicit_centers(): | ||
mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters, | ||
random_state=42, n_init=10) | ||
assert_warns(RuntimeWarning, mb_k_means.fit, X) | ||
|
||
|
||
def test_minibatch_k_means_perfect_init_sparse_csr(): | ||
mb_k_means = MiniBatchKMeans(init=centers.copy(), n_clusters=n_clusters, | ||
random_state=42, n_init=1).fit(X_csr) | ||
@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse']) | ||
@pytest.mark.parametrize('init', ["random", 'k-means++', centers.copy()]) | ||
def test_minibatch_k_means_init(data, init): | ||
mb_k_means = MiniBatchKMeans(init=init, n_clusters=n_clusters, | ||
random_state=42, n_init=10) | ||
mb_k_means.fit(data) | ||
_check_fitted_model(mb_k_means) | ||
|
||
|
||
|
@@ -585,64 +531,39 @@ def test_predict(): | |
assert_array_equal(pred, km.labels_) | ||
|
||
|
||
def test_score(): | ||
|
||
km1 = KMeans(n_clusters=n_clusters, max_iter=1, random_state=42, n_init=1) | ||
s1 = km1.fit(X).score(X) | ||
km2 = KMeans(n_clusters=n_clusters, max_iter=10, random_state=42, n_init=1) | ||
s2 = km2.fit(X).score(X) | ||
assert_greater(s2, s1) | ||
|
||
@pytest.mark.parametrize('algo', ['full', 'elkan']) | ||
def test_score(algo): | ||
# Check that fitting k-means with multiple inits gives better score | ||
km1 = KMeans(n_clusters=n_clusters, max_iter=1, random_state=42, n_init=1, | ||
algorithm='elkan') | ||
algorithm=algo) | ||
s1 = km1.fit(X).score(X) | ||
km2 = KMeans(n_clusters=n_clusters, max_iter=10, random_state=42, n_init=1, | ||
algorithm='elkan') | ||
algorithm=algo) | ||
s2 = km2.fit(X).score(X) | ||
assert_greater(s2, s1) | ||
|
||
|
||
def test_predict_minibatch_dense_input(): | ||
mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, random_state=40).fit(X) | ||
|
||
# sanity check: predict centroid labels | ||
pred = mb_k_means.predict(mb_k_means.cluster_centers_) | ||
assert_array_equal(pred, np.arange(n_clusters)) | ||
|
||
# sanity check: re-predict labeling for training set samples | ||
pred = mb_k_means.predict(X) | ||
assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_) | ||
|
||
|
||
def test_predict_minibatch_kmeanspp_init_sparse_input(): | ||
mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init='k-means++', | ||
n_init=10).fit(X_csr) | ||
@pytest.mark.parametrize('data', [X, X_csr], ids=['dense', 'sparse']) | ||
@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()]) | ||
def test_predict_minibatch(data, init): | ||
mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init=init, | ||
n_init=10, random_state=0).fit(data) | ||
|
||
# sanity check: re-predict labeling for training set samples | ||
assert_array_equal(mb_k_means.predict(X_csr), mb_k_means.labels_) | ||
assert_array_equal(mb_k_means.predict(data), mb_k_means.labels_) | ||
|
||
# sanity check: predict centroid labels | ||
pred = mb_k_means.predict(mb_k_means.cluster_centers_) | ||
assert_array_equal(pred, np.arange(n_clusters)) | ||
|
||
# check that models trained on sparse input also works for dense input at | ||
# predict time | ||
assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_) | ||
|
||
|
||
def test_predict_minibatch_random_init_sparse_input(): | ||
mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init='random', | ||
n_init=10).fit(X_csr) | ||
|
||
# sanity check: re-predict labeling for training set samples | ||
assert_array_equal(mb_k_means.predict(X_csr), mb_k_means.labels_) | ||
|
||
# sanity check: predict centroid labels | ||
pred = mb_k_means.predict(mb_k_means.cluster_centers_) | ||
assert_array_equal(pred, np.arange(n_clusters)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we keep these 2 lines as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did keep them :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in |
||
|
||
@pytest.mark.parametrize('init', ['random', 'k-means++', centers.copy()]) | ||
def test_predict_minibatch_dense_sparse(init): | ||
# check that models trained on sparse input also works for dense input at | ||
# predict time | ||
mb_k_means = MiniBatchKMeans(n_clusters=n_clusters, init=init, | ||
n_init=10, random_state=0).fit(X_csr) | ||
|
||
assert_array_equal(mb_k_means.predict(X), mb_k_means.labels_) | ||
|
||
|
||
|
@@ -694,27 +615,19 @@ def test_fit_transform(): | |
assert_array_almost_equal(X1, X2) | ||
|
||
|
||
def test_predict_equal_labels(): | ||
km = KMeans(random_state=13, n_jobs=1, n_init=1, max_iter=1, | ||
algorithm='full') | ||
km.fit(X) | ||
assert_array_equal(km.predict(X), km.labels_) | ||
|
||
@pytest.mark.parametrize('algo', ['full', 'elkan']) | ||
def test_predict_equal_labels(algo): | ||
km = KMeans(random_state=13, n_jobs=1, n_init=1, max_iter=1, | ||
algorithm='elkan') | ||
algorithm=algo) | ||
km.fit(X) | ||
assert_array_equal(km.predict(X), km.labels_) | ||
|
||
|
||
def test_full_vs_elkan(): | ||
km1 = KMeans(algorithm='full', random_state=13).fit(X) | ||
km2 = KMeans(algorithm='elkan', random_state=13).fit(X) | ||
|
||
km1 = KMeans(algorithm='full', random_state=13) | ||
km2 = KMeans(algorithm='elkan', random_state=13) | ||
|
||
km1.fit(X) | ||
km2.fit(X) | ||
|
||
homogeneity_score(km1.predict(X), km2.predict(X)) == 1.0 | ||
assert homogeneity_score(km1.predict(X), km2.predict(X)) == 1.0 | ||
|
||
|
||
def test_n_init(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we still keep this line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I moved it to a new function:
test_predict_minibatch_dense_sparse
.