Skip to content

ENH Allows target to be pandas nullable dtypes #25638

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/v1.3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ Changelog
during `transform` with no prior call to `fit` or `fit_transform`.
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.

- |Enhancement| :func:`utils.multiclass.type_of_target` can identify pandas
nullable data types as classification targets. :pr:`25638` by `Thomas Fan`_.

:mod:`sklearn.semi_supervised`
..............................

Expand Down
18 changes: 18 additions & 0 deletions sklearn/metrics/tests/test_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,24 @@ def test_confusion_matrix_dtype():
assert cm[1, 1] == -2


@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
def test_confusion_matrix_pandas_nullable(dtype):
    """Ensure confusion_matrix accepts `y_true` stored in a pandas nullable dtype.

    The matrix obtained from the nullable Series is compared against the one
    obtained from the same labels held in a plain ndarray.

    Non-regression test for gh-25635.
    """
    pd = pytest.importorskip("pandas")

    # Same ground-truth labels, once as an ndarray and once as a nullable Series.
    reference_labels = np.array([1, 0, 0, 1, 0, 1, 1, 0, 1])
    predictions = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64")
    nullable_labels = pd.Series(reference_labels, dtype=dtype)

    assert_array_equal(
        confusion_matrix(nullable_labels, predictions),
        confusion_matrix(reference_labels, predictions),
    )


def test_classification_report_multiclass():
# Test performance report
iris = datasets.load_iris()
Expand Down
16 changes: 16 additions & 0 deletions sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ def test_label_binarizer_set_label_encoding():
assert_array_equal(lb.inverse_transform(got), inp)


@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
def test_label_binarizer_pandas_nullable(dtype):
    """Ensure LabelBinarizer can be fitted on a pandas nullable dtype Series.

    After fitting on binary labels stored with a nullable dtype, transforming
    ``[1, 0]`` must give a single indicator column.

    Non-regression test for gh-25637.
    """
    pd = pytest.importorskip("pandas")
    from sklearn.preprocessing import LabelBinarizer

    nullable_labels = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)
    binarizer = LabelBinarizer()
    binarizer.fit(nullable_labels)

    assert_array_equal(binarizer.transform([1, 0]), [[1], [0]])


@ignore_warnings
def test_label_binarizer_errors():
# Check that invalid arguments yield ValueError
Expand Down
35 changes: 29 additions & 6 deletions sklearn/utils/multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,14 +155,25 @@ def is_multilabel(y):
if hasattr(y, "__array__") or isinstance(y, Sequence) or is_array_api:
# DeprecationWarning will be replaced by ValueError, see NEP 34
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
check_y_kwargs = dict(
accept_sparse=True,
allow_nd=True,
force_all_finite=False,
ensure_2d=False,
ensure_min_samples=0,
ensure_min_features=0,
)
with warnings.catch_warnings():
warnings.simplefilter("error", np.VisibleDeprecationWarning)
try:
y = xp.asarray(y)
except (np.VisibleDeprecationWarning, ValueError):
y = check_array(y, dtype=None, **check_y_kwargs)
except (np.VisibleDeprecationWarning, ValueError) as e:
if str(e).startswith("Complex data not supported"):
raise

# dtype=object should be provided explicitly for ragged arrays,
# see NEP 34
y = xp.asarray(y, dtype=object)
y = check_array(y, dtype=object, **check_y_kwargs)

if not (hasattr(y, "shape") and y.ndim == 2 and y.shape[1] > 1):
return False
Expand Down Expand Up @@ -302,15 +313,27 @@ def type_of_target(y, input_name=""):
# https://numpy.org/neps/nep-0034-infer-dtype-is-object.html
# We therefore catch both deprecation (NumPy < 1.24) warning and
# value error (NumPy >= 1.24).
check_y_kwargs = dict(
accept_sparse=True,
allow_nd=True,
force_all_finite=False,
ensure_2d=False,
ensure_min_samples=0,
ensure_min_features=0,
)

with warnings.catch_warnings():
warnings.simplefilter("error", np.VisibleDeprecationWarning)
if not issparse(y):
try:
y = xp.asarray(y)
except (np.VisibleDeprecationWarning, ValueError):
y = check_array(y, dtype=None, **check_y_kwargs)
except (np.VisibleDeprecationWarning, ValueError) as e:
if str(e).startswith("Complex data not supported"):
raise

# dtype=object should be provided explicitly for ragged arrays,
# see NEP 34
y = xp.asarray(y, dtype=object)
y = check_array(y, dtype=object, **check_y_kwargs)

# The old sequence of sequences format
try:
Expand Down
36 changes: 36 additions & 0 deletions sklearn/utils/tests/test_multiclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,42 @@ def test_type_of_target_pandas_sparse():
type_of_target(y)


def test_type_of_target_pandas_nullable():
    """Verify type_of_target's result for pandas nullable Series and DataFrames."""
    pd = pytest.importorskip("pandas")

    # 1d targets: (values, expected target type), checked for each nullable dtype.
    series_cases = [
        ([1, 0, 2, 3, 4], "multiclass"),
        ([1, 0, 1, 0], "binary"),
    ]
    for dtype in ["Int32", "Float32"]:
        for values, expected in series_cases:
            assert type_of_target(pd.Series(values, dtype=dtype)) == expected

    # 2d targets: (rows, nullable dtype, expected target type).
    frame_cases = [
        ([[1.4, 3.1], [3.1, 1.4]], "Float32", "continuous-multioutput"),
        ([[0, 1], [1, 1]], "Int32", "multilabel-indicator"),
        ([[1, 2], [3, 1]], "Int32", "multiclass-multioutput"),
    ]
    for rows, dtype, expected in frame_cases:
        assert type_of_target(pd.DataFrame(rows, dtype=dtype)) == expected


@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
def test_unique_labels_pandas_nullable(dtype):
    """Ensure unique_labels works when one input uses a pandas nullable dtype.

    A nullable-dtype Series is combined with a regular ``int64`` Series and
    the labels found must be ``[0, 1]``.

    Non-regression test for gh-25634.
    """
    pd = pytest.importorskip("pandas")

    predictions = pd.Series([0, 0, 1, 1, 0, 1, 1, 1, 1], dtype="int64")
    nullable_labels = pd.Series([1, 0, 0, 1, 0, 1, 1, 0, 1], dtype=dtype)

    assert_array_equal(unique_labels(nullable_labels, predictions), [0, 1])


def test_class_distribution():
y = np.array(
[
Expand Down