Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion sklearn/svm/_classes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

import warnings
from numbers import Integral, Real

import numpy as np
Expand Down Expand Up @@ -690,6 +691,10 @@ class SVC(BaseSVC):
5-fold cross-validation, and `predict_proba` may be inconsistent with
`predict`. Read more in the :ref:`User Guide <scores_probabilities>`.

..deprecated:: version 1.8 parameter probability will be deprecated in future
versions, use CalibratedClassifierCV(SVC(), ensemble=False) instead
of SVC(probability=True)

tol : float, default=1e-3
Tolerance for stopping criterion.

Expand Down Expand Up @@ -867,7 +872,7 @@ def __init__(
gamma="scale",
coef0=0.0,
shrinking=True,
probability=False,
probability="deprecated",
tol=1e-3,
cache_size=200,
class_weight=None,
Expand All @@ -877,6 +882,16 @@ def __init__(
break_ties=False,
random_state=None,
):
if probability != "deprecated":
warnings.warn(
"parameter `probability` will be deprecated in version 1.8, "
"use `CalibratedClassifierCV(SVC(), ensemble=False)` instead "
"of `SVC(probability=True)`",
FutureWarning,
)
else:
probability = False

super().__init__(
kernel=kernel,
degree=degree,
Expand Down Expand Up @@ -951,6 +966,10 @@ class NuSVC(BaseSVC):
5-fold cross-validation, and `predict_proba` may be inconsistent with
`predict`. Read more in the :ref:`User Guide <scores_probabilities>`.

..deprecated:: version 1.8 parameter probability will be deprecated in future
versions, use CalibratedClassifierCV(NuSVC(), ensemble=False) instead
of NuSVC(probability=True)

tol : float, default=1e-3
Tolerance for stopping criterion.

Expand Down Expand Up @@ -1140,6 +1159,15 @@ def __init__(
break_ties=False,
random_state=None,
):
if probability != "deprecated":
warnings.warn(
"parameter `probability` will be deprecated in version 1.8, "
"use `CalibratedClassifierCV(NuSVC(), ensemble=False)` instead "
"of `NuSVC(probability=True)`",
FutureWarning,
)
else:
probability = False
super().__init__(
kernel=kernel,
degree=degree,
Expand Down
16 changes: 16 additions & 0 deletions sklearn/svm/tests/test_svm.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ def test_tweak_params():
# XXX: this test is thread-unsafe because it uses probability=True:
# https://github.com/scikit-learn/scikit-learn/issues/31885
@pytest.mark.thread_unsafe
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_probability(global_random_seed):
# Predict probabilities using SVC
# This uses cross validation, so we use a slightly bigger testing set.
Expand Down Expand Up @@ -764,6 +765,7 @@ def test_svc_nonfinite_params(global_random_seed):
clf.fit(X, y)


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_unicode_kernel(global_random_seed):
# Test that a unicode kernel name does not cause a TypeError
iris = get_iris_dataset(global_random_seed)
Expand Down Expand Up @@ -1055,6 +1057,7 @@ def test_linearsvc_verbose():
# XXX: this test is thread-unsafe because it uses probability=True:
# https://github.com/scikit-learn/scikit-learn/issues/31885
@pytest.mark.thread_unsafe
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_svc_clone_with_callable_kernel():
iris = get_iris_dataset(42)

Expand Down Expand Up @@ -1102,6 +1105,7 @@ def test_svc_bad_kernel():
# XXX: this test is thread-unsafe because it uses probability=True:
# https://github.com/scikit-learn/scikit-learn/issues/31885
@pytest.mark.thread_unsafe
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_libsvm_convergence_warnings(global_random_seed):
a = svm.SVC(
kernel=lambda x, y: np.dot(x, y.T),
Expand Down Expand Up @@ -1135,6 +1139,7 @@ def test_unfitted():
# https://github.com/scikit-learn/scikit-learn/issues/31885
@pytest.mark.thread_unsafe
@pytest.mark.filterwarnings("ignore::sklearn.exceptions.ConvergenceWarning")
@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_consistent_proba(global_random_seed):
a = svm.SVC(probability=True, max_iter=1, random_state=global_random_seed)
proba_1 = a.fit(X, Y).predict_proba(X)
Expand Down Expand Up @@ -1190,6 +1195,7 @@ def test_lsvc_intercept_scaling_zero():
assert lsvc.intercept_ == 0.0


@pytest.mark.filterwarnings("ignore::FutureWarning")
def test_hasattr_predict_proba(global_random_seed):
iris = get_iris_dataset(global_random_seed)

Expand Down Expand Up @@ -1534,3 +1540,13 @@ def test_svm_with_infinite_C(Estimator, make_dataset, C_inf, global_random_seed)
estimator_C_large = Estimator(C=1e10).fit(X, y)

assert_allclose(estimator_C_large.predict(X), estimator_C_inf.predict(X))


@pytest.mark.parametrize(
"Estimator, name",
[(svm.SVC, "SVC"), (svm.NuSVC, "NuSVC")],
)
def test_probability_raises_futurewarning(Estimator, name):
X, y = make_classification()
with pytest.warns(FutureWarning):
Estimator(probability=True).fit(X, y)
Loading