
Fix for Multiclass SVC.fit fails if sample_weight zeros out a class #26593

Closed
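For context, a minimal sketch of the failure mode described in the title, using the same toy data as the new test further below (assuming pre-fix behavior; the exact error raised may differ):

import numpy as np
from sklearn import svm

X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
y = [0, 1, 2]
# Class 0 carries only zero weights; before this fix, multiclass SVC.fit
# is reported to fail instead of simply ignoring that class.
w = [0.0, 1.0, 1.0]
svm.SVC().fit(X, y, sample_weight=w)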
19 changes: 17 additions & 2 deletions sklearn/svm/_base.py
@@ -11,7 +11,6 @@
from . import _liblinear as liblinear  # type: ignore
from . import _libsvm_sparse as libsvm_sparse  # type: ignore
from ..base import BaseEstimator, ClassifierMixin
-from ..base import _fit_context
from ..preprocessing import LabelEncoder
from ..utils.multiclass import _ovr_decision_function
from ..utils import check_array, check_random_state
@@ -144,7 +143,6 @@ def _more_tags(self):
        # Used by cross_val_score.
        return {"pairwise": self.kernel == "precomputed"}

-    @_fit_context(prefer_skip_nested_validation=True)
Review comment (Member): I think this is an unintended change; it should be put back.
    def fit(self, X, y, sample_weight=None):
        """Fit the SVM model according to the given training data.

@@ -178,6 +176,8 @@ def fit(self, X, y, sample_weight=None):
        If X is a dense array, then the other methods will not support sparse
        matrices as input.
        """
+        self._validate_params()
+
Review comment (Member) on lines +179 to +180: Same here; this should be removed.
        rnd = check_random_state(self.random_state)

        sparse = sp.isspmatrix(X)
@@ -283,6 +283,21 @@
        else:
            self.n_iter_ = self._num_iter.item()

+        # Deal with zero weights
+        # Remove classes associated with zero weights
+        # Only for multi-class
+        if hasattr(self, "classes_"):
+            zero_weight_index = sample_weight == 0
+            if len(zero_weight_index) > 0 and len(self.classes_) > 2:
+                X = X[~zero_weight_index]
+                y = y[~zero_weight_index]
+                sample_weight = sample_weight[~zero_weight_index]
+                y = self._validate_targets(y)  # Changing number of classes and targets
+
+                warnings.warn(
+                    "Removed all classes with zero sample weights",
+                )
+
        return self

    def _validate_targets(self, y):
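For reference, a standalone sketch of the masking step the added block performs (the helper name is hypothetical, and sample_weight is assumed to already be a NumPy array):

import numpy as np

def drop_zero_weight_samples(X, y, sample_weight):
    # Keep only samples with non-zero weight, so any class whose samples are
    # all zero-weighted disappears from the training data entirely.
    keep = np.asarray(sample_weight) != 0
    return X[keep], np.asarray(y)[keep], np.asarray(sample_weight)[keep]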
18 changes: 18 additions & 0 deletions sklearn/svm/tests/test_svm.py
@@ -550,6 +550,24 @@ def test_svm_equivalence_sample_weight_C():
    assert_allclose(dual_coef_no_weight, clf.dual_coef_)


+def test_svm_multiclass_zero_sample_weights():
+    # test that a class with zero sample weight has no effect:
+    # the trained params match those of a model fit without that class
+    X = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
+    y = [0, 1, 2]
+    w = [0.0, 1.0, 1.0]
+    clf = svm.SVC().fit(X, y, w)
+    n_support_zero = clf.n_support_
+    support_vectors_zero = clf.support_vectors_
+    X = np.array([[1.0, 0.0], [0.0, 1.0]])
+    y = [1, 2]
+    w = [1.0, 1.0]
+    clf = svm.SVC().fit(X, y, w)
+
+    assert_allclose(n_support_zero, clf.n_support_)
+    assert_allclose(support_vectors_zero, clf.support_vectors_)
+
+
@pytest.mark.parametrize(
"Estimator, err_msg",
[
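The new test can be run in isolation with pytest's -k selector, e.g. pytest sklearn/svm/tests/test_svm.py -k test_svm_multiclass_zero_sample_weights (path and test name taken from the diff above).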