Skip to content

Commit cd06652

Browse files
FIX nan bug in BaseLabelPropagation (#19271)
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
1 parent 69d6378 commit cd06652

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

doc/whats_new/v1.0.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,13 @@ Changelog
150150
for non-English characters. :pr:`18959` by :user:`Zero <Zeroto521>`
151151
and :user:`wstates <wstates>`.
152152

153+
:mod:`sklearn.semi_supervised`
154+
.................................
155+
156+
- |Fix| Avoid NaN during label propagation in
157+
:class:`~sklearn.semi_supervised.LabelPropagation`.
158+
:pr:`19271` by :user:`Zhaowei Wang <ThuWangzw>`.
159+
153160
Code and Documentation Contributors
154161
-----------------------------------
155162

sklearn/semi_supervised/_label_propagation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,7 @@ def fit(self, X, y):
279279
if self._variant == 'propagation':
280280
normalizer = np.sum(
281281
self.label_distributions_, axis=1)[:, np.newaxis]
282+
normalizer[normalizer == 0] = 1
282283
self.label_distributions_ /= normalizer
283284
self.label_distributions_ = np.where(unlabeled,
284285
self.label_distributions_,

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,19 @@ def test_convergence_warning():
157157
assert_no_warnings(mdl.fit, X, y)
158158

159159

160-
def test_label_propagation_non_zero_normalizer():
160+
@pytest.mark.parametrize("LabelPropagationCls",
161+
[label_propagation.LabelSpreading,
162+
label_propagation.LabelPropagation])
163+
def test_label_propagation_non_zero_normalizer(LabelPropagationCls):
161164
# check that we don't divide by zero in case of null normalizer
162165
# non-regression test for
163166
# https://github.com/scikit-learn/scikit-learn/pull/15946
167+
# https://github.com/scikit-learn/scikit-learn/issues/9292
164168
X = np.array([[100., 100.], [100., 100.], [0., 0.], [0., 0.]])
165169
y = np.array([0, 1, -1, -1])
166-
mdl = label_propagation.LabelSpreading(kernel='knn',
167-
max_iter=100,
168-
n_neighbors=1)
170+
mdl = LabelPropagationCls(kernel='knn',
171+
max_iter=100,
172+
n_neighbors=1)
169173
assert_no_warnings(mdl.fit, X, y)
170174

171175

0 commit comments

Comments
 (0)