Skip to content

Commit 50d3fe9

Browse files
ngshyaglemaitre
authored andcommitted
FIX avoid division by 0 warning in LabelPropagation (#15946)
1 parent afe6e51 commit 50d3fe9

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

sklearn/semi_supervised/_label_propagation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def fit(self, X, y):
290290
self.n_iter_ += 1
291291

292292
normalizer = np.sum(self.label_distributions_, axis=1)[:, np.newaxis]
293+
normalizer[normalizer == 0] = 1
293294
self.label_distributions_ /= normalizer
294295

295296
# set the transduction item

sklearn/semi_supervised/tests/test_label_propagation.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,18 @@ def test_convergence_warning():
157157
assert_no_warnings(mdl.fit, X, y)
158158

159159

160+
def test_label_propagation_non_zero_normalizer():
161+
# check that we don't divide by zero in case of null normalizer
162+
# non-regression test for
163+
# https://github.com/scikit-learn/scikit-learn/pull/15946
164+
X = np.array([[100., 100.], [100., 100.], [0., 0.], [0., 0.]])
165+
y = np.array([0, 1, -1, -1])
166+
mdl = label_propagation.LabelSpreading(kernel='knn',
167+
max_iter=100,
168+
n_neighbors=1)
169+
assert_no_warnings(mdl.fit, X, y)
170+
171+
160172
def test_predict_sparse_callable_kernel():
161173
# This is a non-regression test for #15866
162174

0 commit comments

Comments
 (0)