new tests for mean_shift algo #13179
@ogrisel can help review this
def test_mean_shift_negative_bandwidth():
    bandwidth = -1
    ms = MeanShift(bandwidth=bandwidth)
    msg = \
Use parentheses to enclose expressions and split them over multiple lines rather than using \ for line continuation
@jnothman this comment is not clear. Will the following statement work?
msg = "bandwidth needs to be greater than zero or None,"
" got -1.000000"
This will:
msg = ("bandwidth needs to be greater than zero or None,"
" got -1.000000")
changed
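For reference, a minimal sketch of the two styles discussed in this thread. Adjacent string literals are joined by Python at compile time, so wrapping them in parentheses lets the message span several lines without a backslash. The variable names below are only illustrative.

```python
# Backslash continuation: works, but is what the review asks to avoid.
msg_backslash = \
    "bandwidth needs to be greater than zero or None," \
    " got -1.000000"

# Parenthesized form: adjacent string literals are concatenated implicitly.
msg_parens = ("bandwidth needs to be greater than zero or None,"
              " got -1.000000")

# Both spellings build exactly the same string.
assert msg_backslash == msg_parens
```

Without the parentheses (and without a backslash), the second literal would be a separate, indented statement and the module would fail to compile.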

def test_seeds():
    ms = MeanShift(seeds=None)
    _ = ms.fit(X).labels_
Why do you get labels_?
removed
    assert_raise_message(ValueError, msg, ms.fit, X)


def test_seeds():
I don't get what this is testing. Checking that parameters are maintained should usually be covered by the common tests, not by tests for each specific estimator.
removed
    labels = ms.fit(X).labels_
    labels_unique = np.unique(labels)
    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_ > n_clusters, True)
Use bare assert as with seeds above
done
    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_ > n_clusters, True)

    cluster_centers, labels = mean_shift(X, bandwidth=bandwidth,
Rather than repeat the code, please use pytest.mark.parametrize to test multiple settings of bandwidth.
changed to use pytest.mark.parametrize
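As a rough illustration of the request above, here is a minimal sketch of a parametrized test. Note that pytest spells the marker `parametrize`. The helper `count_clusters` and the label lists are hypothetical stand-ins and are not part of scikit-learn or of this pull request.

```python
import pytest


def count_clusters(labels):
    # Hypothetical helper: number of distinct cluster labels,
    # ignoring -1, which marks unclustered points.
    return len(set(labels) - {-1})


@pytest.mark.parametrize(
    "labels, expected",
    [
        ([0, 0, 1, 1, 2], 3),   # every point assigned to a cluster
        ([0, 0, 1, 1, -1], 2),  # one point left unclustered
    ],
)
def test_count_clusters(labels, expected):
    # pytest runs this body once per (labels, expected) tuple.
    assert count_clusters(labels) == expected
```

Each tuple is reported as its own test case, so a failure for one setting does not hide the result of the others.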
@jnothman please review
Force-pushed from b018e99 to 4cf6413.
I confirm this covers untested lines.
    bandwidth = -1
    ms = MeanShift(bandwidth=bandwidth)
    msg = ("bandwidth needs to be greater than zero or None,"
           " got -1.000000")
This whitespace looks like an error in the code raising the message. Please change the code to have a single space between the comma and "got"
This is unresolved. Please fix the error message in mean_shift_.py
    (1.2, True, 3),
    (1.2, False, 4)
])
def test_eval(bandwidth, cluster_all, expected):
What do you mean by calling this "eval"? Can't we just parametrize test_mean_shift above, rather than adding a new test?
But ideally we should also test that cluster_all=False is actually effective at allowing some points to be left unclustered. Create a dataset where a point will be left with label -1 to test this properly.
@jnothman fixed as suggested
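To make the cluster_all=False behaviour concrete, here is a small pure-Python sketch of the idea: a point whose nearest cluster centre is farther away than the bandwidth is left as an orphan with label -1. The function `assign_labels` and the toy data are hypothetical; this is not scikit-learn's implementation, only an illustration under that assumption.

```python
import math


def assign_labels(points, centers, bandwidth, cluster_all=True):
    """Hypothetical sketch of the labelling step, not scikit-learn's code."""
    labels = []
    for p in points:
        dists = [math.dist(p, c) for c in centers]
        nearest = min(range(len(centers)), key=dists.__getitem__)
        if cluster_all or dists[nearest] <= bandwidth:
            labels.append(nearest)
        else:
            labels.append(-1)  # orphan: no centre within the bandwidth
    return labels


# Toy data: the last point is far from both centres.
points = [(0.0, 0.0), (0.1, 0.0), (5.0, 5.0), (50.0, 50.0)]
centers = [(0.0, 0.0), (5.0, 5.0)]

print(assign_labels(points, centers, bandwidth=1.2, cluster_all=True))
# [0, 0, 1, 1]
print(assign_labels(points, centers, bandwidth=1.2, cluster_all=False))
# [0, 0, 1, -1]
```

A test dataset built the same way, with one deliberate outlier, gives a direct check that -1 appears in the labels only when cluster_all is False.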
Force-pushed from 4cf6413 to dea8840.
Please merge the current master
def test_mean_shift():
@pytest.mark.parametrize("bandwidth, cluster_all, expected, "
                         "first_cluster_label",
                         [(1.2, True, 3, 0), (1.2, False, 4, -1)])
Much clearer, thanks!
Force-pushed from bb1dd95 to f40648d.
@jnothman fixed the comments
Thanks @rajdeepd
@jnothman how do we get this pull request merged into master?
4 days is not long to wait for a second review, @rajdeepd... hopefully one will come soon.

    cluster_centers, labels = mean_shift(X, bandwidth=bandwidth)
Removing this means we are not testing the mean_shift function directly anymore.
we are testing using
ms = MeanShift(bandwidth=bandwidth, cluster_all=cluster_all)
labels = ms.fit(X).labels_
The testing of mean_shift should be independent of ms.fit. At the moment, ms.fit calls mean_shift, but we do not know how the code base will change.
@thomasjpfan do we need another test for mean_shift?
Leaving the original test here will sufficiently test mean_shift.
@thomasjpfan added test for mean_shift as well
    ms = MeanShift(bandwidth=bandwidth)
    msg = ("bandwidth needs to be greater than zero or None,"
           " got -1.000000")
    assert_raise_message(ValueError, msg, ms.fit, X)
We are moving to using pytest.raises:
msg = (r"bandwidth needs to be greater than zero or None,"
       r" got -1\.000000")
with pytest.raises(ValueError, match=msg):
    ms.fit(X)
@thomasjpfan fixed
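A short sketch of why the suggested pattern uses raw strings and `\.`: the `match` argument of `pytest.raises` is treated as a regular expression and checked with `re.search` against the string form of the exception. The example below uses only the standard `re` module to show the difference; the `similar` message is made up for illustration.

```python
import re

raised = "bandwidth needs to be greater than zero or None, got -1.000000"
similar = "bandwidth needs to be greater than zero or None, got -1x000000"

loose = "got -1.000000"      # "." matches any character
strict = r"got -1\.000000"   # "\." matches a literal dot only

assert re.search(loose, raised)
assert re.search(loose, similar)           # unintended match
assert re.search(strict, raised)
assert re.search(strict, similar) is None  # rejected, as intended

# re.escape builds a strict pattern from the literal text.
assert re.search(re.escape("got -1.000000"), raised)
```

When the whole message should be matched literally, passing it through `re.escape` avoids escaping characters by hand.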
Force-pushed from 71df239 to 1b9f928.
LGTM otherwise
    n_clusters_ = len(labels_unique)
    assert_equal(n_clusters_, n_clusters)
    cluster_centers, labels_mean_shift = mean_shift(X, cluster_all=cluster_all)
    print(cluster_centers)
please remove
removed
@@ -36,23 +37,36 @@ def test_estimate_bandwidth_1sample():
    # Test estimate_bandwidth when n_samples=1 and quantile<1, so that
    # n_neighbors is set to 1.
    bandwidth = estimate_bandwidth(X, n_samples=1, quantile=0.3)
    assert_array_almost_equal(bandwidth, 0., decimal=5)
    assert_equal(bandwidth, 0.)
could just be assert a == b then
updated @NicolasHug
Force-pushed from 1b9f928 to aa17ea1.
Thanks @rajdeepd
This reverts commit 67f53dc.
Reference Issues/PRs
none
What does this implement/fix? Explain your changes.
Add test cases to cover un-tested portions of mean_shift.py
Any other comments?
no other comments