Skip to content

ENH Array API support for LabelEncoder #27381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
4869c0d
ENH Array API support for LabelEncoder
OmarManzoor Sep 15, 2023
7fbd458
Add changelog
OmarManzoor Sep 15, 2023
ec6ccc6
Add tests for array api functions
OmarManzoor Sep 15, 2023
a9d94ea
Merge branch 'main' into label_encoder_array_api
OmarManzoor Sep 23, 2023
43b039d
Updates: PR suggestions
OmarManzoor Sep 23, 2023
6198558
Merge branch 'main' into label_encoder_array_api
OmarManzoor Apr 4, 2024
cfdabeb
Fix dtype_name parameter
OmarManzoor Apr 4, 2024
cfcabd2
Merge branch 'main' into label_encoder_array_api
Apr 12, 2024
23ee510
Updates as suggested in review
Apr 12, 2024
fa0e27c
Merge branch 'main' into label_encoder_array_api
May 3, 2024
6177475
Revert changes is estimator_checks
May 3, 2024
a21a490
Improve the tests and handle device in _in1d
OmarManzoor May 3, 2024
b09b57b
Fix missing device specification and explicit conversion to numpy
ogrisel May 3, 2024
0544c32
Fix _isin to work with Array API inputs
ogrisel May 3, 2024
7cbbc20
Merge branch 'main' into label_encoder_array_api
ogrisel May 3, 2024
a34138b
Fix the errors, make searchsorted a helper function
OmarManzoor May 6, 2024
58c5aa0
Merge branch 'main' into label_encoder_array_api
OmarManzoor May 6, 2024
beb036a
Add array_api_support tag
OmarManzoor May 6, 2024
34c2d92
Updates: according to some pr suggestions
OmarManzoor May 7, 2024
bdb2d7e
Merge branch 'main' into label_encoder_array_api
ogrisel May 7, 2024
db32acf
Use xp.isdtype(values.dtype, "numeric") directly
ogrisel May 7, 2024
a593478
Update changelog
OmarManzoor May 7, 2024
22fa611
Update docstring for inverse transform
OmarManzoor May 7, 2024
f814441
Change array-like to array
OmarManzoor May 7, 2024
b5350ea
Merge branch 'main' into label_encoder_array_api
OmarManzoor May 8, 2024
8ce860d
Update the changelog definition to make it consistent
OmarManzoor May 8, 2024
fae25aa
Revert and update parameter and return type names
OmarManzoor May 10, 2024
e1bca48
Merge branch 'main' into label_encoder_array_api
OmarManzoor May 15, 2024
30f026b
Merge remote-tracking branch 'upstream/main' into label_encoder_array…
OmarManzoor May 16, 2024
dbf233a
Updates: Address further PR suggestions
OmarManzoor May 16, 2024
7500c2f
Minor adjustment
OmarManzoor May 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ See :ref:`array_api` for more details.

**Classes:**

-
- :class:`preprocessing.LabelEncoder` now supports Array API compatible inputs.
:pr:`27381` by :user:`Omar Salman <OmarManzoor>`.

Metadata Routing
----------------
Expand Down
23 changes: 15 additions & 8 deletions sklearn/preprocessing/_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ..base import BaseEstimator, TransformerMixin, _fit_context
from ..utils import column_or_1d
from ..utils._array_api import _setdiff1d, device, get_namespace
from ..utils._encode import _encode, _unique
from ..utils._param_validation import Interval, validate_params
from ..utils.multiclass import type_of_target, unique_labels
Expand Down Expand Up @@ -129,10 +130,11 @@ def transform(self, y):
Labels as normalized encodings.
"""
check_is_fitted(self)
xp, _ = get_namespace(y)
y = column_or_1d(y, dtype=self.classes_.dtype, warn=True)
# transform of empty array is empty array
if _num_samples(y) == 0:
return np.array([])
return xp.asarray([])

return _encode(y, uniques=self.classes_)

Expand All @@ -141,7 +143,7 @@ def inverse_transform(self, y):

Parameters
----------
y : ndarray of shape (n_samples,)
y : array-like of shape (n_samples,)
Target values.

Returns
Expand All @@ -150,19 +152,24 @@ def inverse_transform(self, y):
Original encoding.
"""
check_is_fitted(self)
xp, _ = get_namespace(y)
y = column_or_1d(y, warn=True)
# inverse transform of empty array is empty array
if _num_samples(y) == 0:
return np.array([])
return xp.asarray([])

diff = np.setdiff1d(y, np.arange(len(self.classes_)))
if len(diff):
diff = _setdiff1d(
ar1=y,
ar2=xp.arange(self.classes_.shape[0], device=device(y)),
xp=xp,
)
if diff.shape[0]:
raise ValueError("y contains previously unseen labels: %s" % str(diff))
y = np.asarray(y)
return self.classes_[y]
y = xp.asarray(y)
return xp.take(self.classes_, y, axis=0)

def _more_tags(self):
return {"X_types": ["1dlabels"]}
return {"X_types": ["1dlabels"], "array_api_support": True}


class LabelBinarizer(TransformerMixin, BaseEstimator, auto_wrap_output_keys=None):
Expand Down
52 changes: 50 additions & 2 deletions sklearn/preprocessing/tests/test_label.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
from scipy.sparse import issparse

from sklearn import datasets
from sklearn import config_context, datasets
from sklearn.preprocessing._label import (
LabelBinarizer,
LabelEncoder,
Expand All @@ -11,7 +11,16 @@
_inverse_binarize_thresholding,
label_binarize,
)
from sklearn.utils._testing import assert_array_equal, ignore_warnings
from sklearn.utils._array_api import (
_convert_to_numpy,
get_namespace,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
_array_api_for_tests,
assert_array_equal,
ignore_warnings,
)
from sklearn.utils.fixes import (
COO_CONTAINERS,
CSC_CONTAINERS,
Expand Down Expand Up @@ -697,3 +706,42 @@ def test_label_encoders_do_not_have_set_output(encoder):
y_encoded_with_kwarg = encoder.fit_transform(y=["a", "b", "c"])
y_encoded_positional = encoder.fit_transform(["a", "b", "c"])
assert_array_equal(y_encoded_with_kwarg, y_encoded_positional)


@pytest.mark.parametrize(
"array_namespace, device, dtype", yield_namespace_device_dtype_combinations()
)
@pytest.mark.parametrize(
"y",
[
np.array([2, 1, 3, 1, 3]),
np.array([1, 1, 4, 5, -1, 0]),
np.array([3, 5, 9, 5, 9, 3]),
],
)
def test_label_encoder_array_api_compliance(y, array_namespace, device, dtype):
xp = _array_api_for_tests(array_namespace, device)
xp_y = xp.asarray(y, device=device)
with config_context(array_api_dispatch=True):
xp_label = LabelEncoder()
np_label = LabelEncoder()
xp_label = xp_label.fit(xp_y)
xp_transformed = xp_label.transform(xp_y)
xp_inv_transformed = xp_label.inverse_transform(xp_transformed)
np_label = np_label.fit(y)
np_transformed = np_label.transform(y)
assert get_namespace(xp_transformed)[0].__name__ == xp.__name__
assert get_namespace(xp_inv_transformed)[0].__name__ == xp.__name__
assert get_namespace(xp_label.classes_)[0].__name__ == xp.__name__
assert_array_equal(_convert_to_numpy(xp_transformed, xp), np_transformed)
assert_array_equal(_convert_to_numpy(xp_inv_transformed, xp), y)
assert_array_equal(_convert_to_numpy(xp_label.classes_, xp), np_label.classes_)

xp_label = LabelEncoder()
np_label = LabelEncoder()
xp_transformed = xp_label.fit_transform(xp_y)
np_transformed = np_label.fit_transform(y)
assert get_namespace(xp_transformed)[0].__name__ == xp.__name__
assert get_namespace(xp_label.classes_)[0].__name__ == xp.__name__
assert_array_equal(_convert_to_numpy(xp_transformed, xp), np_transformed)
assert_array_equal(_convert_to_numpy(xp_label.classes_, xp), np_label.classes_)
123 changes: 123 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ def unique_counts(self, x):
def unique_values(self, x):
return numpy.unique(x)

def unique_all(self, x):
return numpy.unique(
x, return_index=True, return_inverse=True, return_counts=True
)

def concat(self, arrays, *, axis=None):
return numpy.concatenate(arrays, axis=axis)

Expand Down Expand Up @@ -839,3 +844,121 @@ def indexing_dtype(xp):
# TODO: once sufficiently adopted, we might want to instead rely on the
# newer inspection API: https://github.com/data-apis/array-api/issues/640
return xp.asarray(0).dtype


def _searchsorted(xp, a, v, *, side="left", sorter=None):
# Temporary workaround needed as long as searchsorted is not widely
# adopted by implementers of the Array API spec. This is a quite
# recent addition to the spec:
# https://data-apis.org/array-api/latest/API_specification/generated/array_api.searchsorted.html # noqa
if hasattr(xp, "searchsorted"):
return xp.searchsorted(a, v, side=side, sorter=sorter)

a_np = _convert_to_numpy(a, xp=xp)
v_np = _convert_to_numpy(v, xp=xp)
indices = numpy.searchsorted(a_np, v_np, side=side, sorter=sorter)
return xp.asarray(indices, device=device(a))


def _setdiff1d(ar1, ar2, xp, assume_unique=False):
"""Find the set difference of two arrays.

Return the unique values in `ar1` that are not in `ar2`.
"""
if _is_numpy_namespace(xp):
return xp.asarray(
numpy.setdiff1d(
ar1=ar1,
ar2=ar2,
assume_unique=assume_unique,
)
)

if assume_unique:
ar1 = xp.reshape(ar1, (-1,))
else:
ar1 = xp.unique_values(ar1)
ar2 = xp.unique_values(ar2)
return ar1[_in1d(ar1=ar1, ar2=ar2, xp=xp, assume_unique=True, invert=True)]


def _isin(element, test_elements, xp, assume_unique=False, invert=False):
"""Calculates ``element in test_elements``, broadcasting over `element`
only.

Returns a boolean array of the same shape as `element` that is True
where an element of `element` is in `test_elements` and False otherwise.
"""
if _is_numpy_namespace(xp):
return xp.asarray(
numpy.isin(
element=element,
test_elements=test_elements,
assume_unique=assume_unique,
invert=invert,
)
)

original_element_shape = element.shape
element = xp.reshape(element, (-1,))
test_elements = xp.reshape(test_elements, (-1,))
return xp.reshape(
_in1d(
ar1=element,
ar2=test_elements,
xp=xp,
assume_unique=assume_unique,
invert=invert,
),
original_element_shape,
)


# Note: This is a helper for the functions `_isin` and
# `_setdiff1d`. It is not meant to be called directly.
def _in1d(ar1, ar2, xp, assume_unique=False, invert=False):
"""Checks whether each element of an array is also present in a
second array.

Returns a boolean array the same length as `ar1` that is True
where an element of `ar1` is in `ar2` and False otherwise.

This function has been adapted using the original implementation
present in numpy:
https://github.com/numpy/numpy/blob/v1.26.0/numpy/lib/arraysetops.py#L524-L758
"""
xp, _ = get_namespace(ar1, ar2, xp=xp)

# This code is run to make the code significantly faster
if ar2.shape[0] < 10 * ar1.shape[0] ** 0.145:
if invert:
mask = xp.ones(ar1.shape[0], dtype=xp.bool, device=device(ar1))
for a in ar2:
mask &= ar1 != a
else:
mask = xp.zeros(ar1.shape[0], dtype=xp.bool, device=device(ar1))
for a in ar2:
mask |= ar1 == a
return mask

if not assume_unique:
ar1, rev_idx = xp.unique_inverse(ar1)
ar2 = xp.unique_values(ar2)

ar = xp.concat((ar1, ar2))
device_ = device(ar)
# We need this to be a stable sort.
order = xp.argsort(ar, stable=True)
reverse_order = xp.argsort(order, stable=True)
sar = xp.take(ar, order, axis=0)
if invert:
bool_ar = sar[1:] != sar[:-1]
else:
bool_ar = sar[1:] == sar[:-1]
flag = xp.concat((bool_ar, xp.asarray([invert], device=device_)))
ret = xp.take(flag, reverse_order, axis=0)

if assume_unique:
return ret[: ar1.shape[0]]
else:
return xp.take(ret, rev_idx, axis=0)
Loading
Loading