Skip to content

FIX nan bug in BaseLabelPropagation #19271

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Feb 1, 2021

Conversation

ThuWangzw
Copy link
Contributor

Reference Issues/PRs

Fixes #9292.

What does this implement/fix? Explain your changes.

Label distribution of some samples may be all zero, which causes nan error during normalization. Add normalizer[normalizer == 0] = 1 will fix this bug.

Copy link
Member

@jjerphan jjerphan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ThuWangzw, thanks for this contribution.

Can you add a regression test which assert the correctness of this fix?

@ThuWangzw
Copy link
Contributor Author

Thanks for your attention and I have added the test. (By the way, I think this bug is quite similar to #15946.)

@jjerphan
Copy link
Member

Indeed, you just complete the fix of #15946 for LabelPropagation.

To remove some code duplication, I would suggest extending the test introduced by #15946 using test parametrisation on the class. This can be done with something like:

@pytest.mark.parametrize("label_propagation_class",
                         [label_propagation.LabelSpreading,
                          label_propagation.LabelPropagation])
def test_label_propagation_non_zero_normalizer(label_propagation_class):
    # check that we don't divide by zero in case of null normalizer
    # non-regression test for
    # https://github.com/scikit-learn/scikit-learn/pull/15946
    # https://github.com/scikit-learn/scikit-learn/pull/19271
    X = np.array([[100., 100.], [100., 100.], [0., 0.], [0., 0.]])
    y = np.array([0, 1, -1, -1])
    mdl = label_propagation_class(kernel='knn',
                                  max_iter=100,
                                  n_neighbors=1)
    assert_no_warnings(mdl.fit, X, y)

@ThuWangzw
Copy link
Contributor Author

I replaced the duplicated testing codes with yours :). Thank you!

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good. Thank you for the PR @ThuWangzw !

Copy link
Member

@thomasjpfan thomasjpfan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
Please add an entry to the change log at doc/whats_new/v1.0.rst with tag |Fix|. Like the other entries there, please reference this pull request with :pr: and credit yourself (and other contributors if applicable) with :user:.

@jjerphan
Copy link
Member

Thanks for the PR @ThuWangzw !

@ThuWangzw
Copy link
Contributor Author

Actually this is my first time contributing to open source community. Thanks for your suggestions @jjerphan and @thomasjpfan !

@jjerphan
Copy link
Member

You're welcome @ThuWangzw; I hope you enjoy this first contribution!

Copy link
Member

@rth rth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks LGTM.

I'm still surprised there is no more elegant way of doing division with the 0/0 special case without masking.

@rth rth changed the title Bug fix: nan bug in BaseLabelPropagation FIX nan bug in BaseLabelPropagation Feb 1, 2021
@rth rth merged commit cd06652 into scikit-learn:main Feb 1, 2021
@glemaitre glemaitre added the To backport PR merged in master that need a backport to a release branch defined based on the milestone. label Feb 11, 2021
@glemaitre glemaitre added this to the 0.24.2 milestone Feb 11, 2021
@glemaitre glemaitre mentioned this pull request Apr 22, 2021
12 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module:semi_supervised To backport PR merged in master that need a backport to a release branch defined based on the milestone.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] Label propagation sometimes produces label_distributions that contain Nan.
5 participants