Skip to content
23 changes: 11 additions & 12 deletions sklearn/metrics/_plot/tests/test_roc_curve_display.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from sklearn import clone
from sklearn.compose import make_column_transformer
from sklearn.datasets import load_breast_cancer, load_iris
from sklearn.datasets import load_breast_cancer, make_classification
from sklearn.exceptions import NotFittedError
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import RocCurveDisplay, auc, roc_curve
Expand All @@ -16,20 +16,19 @@


@pytest.fixture(scope="module")
def data():
X, y = load_iris(return_X_y=True)
# Avoid introducing test dependencies by mistake.
X.flags.writeable = False
y.flags.writeable = False
def data_binary():
X, y = make_classification(
n_samples=200,
n_features=20,
n_informative=5,
n_redundant=2,
flip_y=0.1,
class_sep=0.8,
random_state=42,
)
return X, y


@pytest.fixture(scope="module")
def data_binary(data):
X, y = data
return X[y < 2], y[y < 2]


@pytest.mark.parametrize("response_method", ["predict_proba", "decision_function"])
@pytest.mark.parametrize("with_sample_weight", [True, False])
@pytest.mark.parametrize("drop_intermediate", [True, False])
Expand Down