Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f6ff0ed
ENH Add array api support for PolynomialFeatures
OmarManzoor Jun 17, 2025
63111d4
Add benchmark
OmarManzoor Jun 18, 2025
e9e1b16
Add benchmark
OmarManzoor Jun 18, 2025
17fa68a
Remove benchmark file for testing
OmarManzoor Jun 18, 2025
a72f6e9
Add benchmark again for testing
OmarManzoor Jun 18, 2025
c58ad46
Remove benchmark file for testing
OmarManzoor Jun 18, 2025
1518126
Add in documentation
OmarManzoor Jun 18, 2025
2e96584
Add changelog
OmarManzoor Jun 18, 2025
a4f45b9
Merge branch 'main' into array-api-poly-features
OmarManzoor Jun 19, 2025
a9055f5
Refactor code in supported_float_dtypes
OmarManzoor Jun 19, 2025
e38629f
Merge branch 'main' into array-api-poly-features
OmarManzoor Jun 19, 2025
7aaa83b
Update the supported float dtypes function
OmarManzoor Jun 19, 2025
11906cc
Merge branch 'main' into array-api-poly-features
OmarManzoor Jun 20, 2025
6620df5
Add device check in test
OmarManzoor Jun 20, 2025
596e10c
Add array api tag
OmarManzoor Jun 23, 2025
d171ea1
Address PR suggestions
OmarManzoor Jun 26, 2025
e9fdca1
Minor updates
OmarManzoor Jun 26, 2025
f9f83b8
Merge branch 'main' into array-api-poly-features
OmarManzoor Jun 26, 2025
66084a7
Minor fix
OmarManzoor Jun 26, 2025
dd328c5
Sync with main
OmarManzoor Jun 26, 2025
5618b80
Add desc in test
OmarManzoor Jun 26, 2025
95894a1
Address PR suggestions
OmarManzoor Jun 27, 2025
789aa09
Assert dtype
OmarManzoor Jun 27, 2025
8ec11ef
Add a test for supported_float_types
OmarManzoor Jul 1, 2025
33a0be1
Merge branch 'main' into array-api-poly-features
OmarManzoor Jul 1, 2025
ef7e9ee
Improve the checking of dtypes rather than str in the test for suppor…
OmarManzoor Jul 1, 2025
4804b43
Update sklearn/preprocessing/_polynomial.py
OmarManzoor Jul 1, 2025
ebbf4f8
Update sklearn/preprocessing/tests/test_polynomial.py
OmarManzoor Jul 1, 2025
8e488f0
Update sklearn/preprocessing/tests/test_polynomial.py
OmarManzoor Jul 1, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ Estimators
- :class:`preprocessing.MaxAbsScaler`
- :class:`preprocessing.MinMaxScaler`
- :class:`preprocessing.Normalizer`
- :class:`preprocessing.PolynomialFeatures`
- :class:`mixture.GaussianMixture` (with `init_params="random"` or
`init_params="random_from_data"` and `warm_start=False`)

Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/31580.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :class:`preprocessing.PolynomialFeatures` now supports array API compatible inputs.
By :user:`Omar Salman <OmarManzoor>`
55 changes: 41 additions & 14 deletions sklearn/preprocessing/_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@
from scipy.interpolate import BSpline
from scipy.special import comb

from sklearn.utils._array_api import (
_is_numpy_namespace,
get_namespace_and_device,
supported_float_dtypes,
)

from ..base import BaseEstimator, TransformerMixin, _fit_context
from ..utils import check_array
from ..utils._mask import _get_mask
Expand Down Expand Up @@ -416,18 +422,18 @@ def transform(self, X):
`csr_matrix`.
"""
check_is_fitted(self)

xp, _, device_ = get_namespace_and_device(X)
X = validate_data(
self,
X,
order="F",
dtype=FLOAT_DTYPES,
dtype=supported_float_dtypes(xp=xp, device=device_),
reset=False,
accept_sparse=("csr", "csc"),
)

n_samples, n_features = X.shape
max_int32 = np.iinfo(np.int32).max
max_int32 = xp.iinfo(xp.int32).max
if sparse.issparse(X) and X.format == "csr":
if self._max_degree > 3:
return self.transform(X.tocsc()).tocsr()
Expand Down Expand Up @@ -497,8 +503,19 @@ def transform(self, X):
else:
# Do as if _min_degree = 0 and cut down array after the
# computation, i.e. use _n_out_full instead of n_output_features_.
XP = np.empty(
shape=(n_samples, self._n_out_full), dtype=X.dtype, order=self.order
order_kwargs = {}
if _is_numpy_namespace(xp=xp):
order_kwargs["order"] = self.order
elif self.order == "F":
raise ValueError(
"PolynomialFeatures does not support order='F' for non-numpy arrays"
)

XP = xp.empty(
shape=(n_samples, self._n_out_full),
dtype=X.dtype,
device=device_,
**order_kwargs,
)

# What follows is a faster implementation of:
Expand Down Expand Up @@ -544,12 +561,18 @@ def transform(self, X):
break
# XP[:, start:end] are terms of degree d - 1
# that exclude feature #feature_idx.
np.multiply(
XP[:, start:end],
X[:, feature_idx : feature_idx + 1],
out=XP[:, current_col:next_col],
casting="no",
)
if _is_numpy_namespace(xp):
# numpy performs this multiplication in place
np.multiply(
XP[:, start:end],
X[:, feature_idx : feature_idx + 1],
out=XP[:, current_col:next_col],
casting="no",
)
else:
XP[:, current_col:next_col] = xp.multiply(
XP[:, start:end], X[:, feature_idx : feature_idx + 1]
)
current_col = next_col

new_index.append(current_col)
Expand All @@ -558,19 +581,23 @@ def transform(self, X):
if self._min_degree > 1:
n_XP, n_Xout = self._n_out_full, self.n_output_features_
if self.include_bias:
Xout = np.empty(
shape=(n_samples, n_Xout), dtype=XP.dtype, order=self.order
Xout = xp.empty(
shape=(n_samples, n_Xout),
dtype=XP.dtype,
device=device_,
**order_kwargs,
)
Xout[:, 0] = 1
Xout[:, 1:] = XP[:, n_XP - n_Xout + 1 :]
else:
Xout = XP[:, n_XP - n_Xout :].copy()
Xout = xp.asarray(XP[:, n_XP - n_Xout :], copy=True)
XP = Xout
return XP

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.sparse = True
tags.array_api_support = True
return tags


Expand Down
71 changes: 71 additions & 0 deletions sklearn/preprocessing/tests/test_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from scipy.interpolate import BSpline
from scipy.sparse import random as sparse_random

from sklearn._config import config_context
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import (
Expand All @@ -18,8 +19,17 @@
from sklearn.preprocessing._csr_polynomial_expansion import (
_get_sizeof_LARGEST_INT_t,
)
from sklearn.utils._array_api import (
_convert_to_numpy,
_get_namespace_device_dtype_ids,
_is_numpy_namespace,
device,
get_namespace,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._mask import _get_mask
from sklearn.utils._testing import (
_array_api_for_tests,
assert_allclose_dense_sparse,
assert_array_almost_equal,
)
Expand Down Expand Up @@ -1336,3 +1346,64 @@ def test_csr_polynomial_expansion_windows_fail(csr_container):
X_trans = pf.fit_transform(X)
for idx in range(3):
assert X_trans[0, expected_indices[idx]] == pytest.approx(1.0)


@pytest.mark.parametrize(
"array_namespace, device_, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize("interaction_only", [True, False])
@pytest.mark.parametrize("include_bias", [True, False])
@pytest.mark.parametrize("degree", [2, (2, 2), 3, (3, 3)])
def test_polynomial_features_array_api_compliance(
two_features_degree3,
degree,
include_bias,
interaction_only,
array_namespace,
device_,
dtype_name,
):
"""Test array API compliance for PolynomialFeatures on 2 features up to degree 3."""
xp = _array_api_for_tests(array_namespace, device_)
X, _ = two_features_degree3
X_np = X.astype(dtype_name)
X_xp = xp.asarray(X_np, device=device_)
with config_context(array_api_dispatch=True):
tf_np = PolynomialFeatures(
degree=degree, include_bias=include_bias, interaction_only=interaction_only
).fit(X_np)

tf_xp = PolynomialFeatures(
degree=degree, include_bias=include_bias, interaction_only=interaction_only
).fit(X_xp)
out_np = tf_np.transform(X_np)
out_xp = tf_xp.transform(X_xp)
assert_allclose(_convert_to_numpy(out_xp, xp=xp), out_np)
assert get_namespace(out_xp)[0].__name__ == xp.__name__
assert device(out_xp) == device(X_xp)
assert out_xp.dtype == X_xp.dtype


@pytest.mark.parametrize(
"array_namespace, device_, dtype_name",
yield_namespace_device_dtype_combinations(),
ids=_get_namespace_device_dtype_ids,
)
def test_polynomial_features_array_api_raises_on_order_F(
array_namespace, device_, dtype_name
):
"""Test that PolynomialFeatures with order='F' raises ValueError on
array API namespaces other than numpy."""
xp = _array_api_for_tests(array_namespace, device_)
X = np.arange(6).reshape((3, 2)).astype(dtype_name)
X_xp = xp.asarray(X, device=device_)
msg = "PolynomialFeatures does not support order='F' for non-numpy arrays"
with config_context(array_api_dispatch=True):
pf = PolynomialFeatures(order="F").fit(X_xp)
if _is_numpy_namespace(xp): # Numpy should not raise
pf.transform(X_xp)
else:
with pytest.raises(ValueError, match=msg):
pf.transform(X_xp)
16 changes: 12 additions & 4 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def _isdtype_single(dtype, kind, *, xp):
return dtype == kind


def supported_float_dtypes(xp):
def supported_float_dtypes(xp, device=None):
"""Supported floating point types for the namespace.

Note: float16 is not officially part of the Array API spec at the
Expand All @@ -299,10 +299,18 @@ def supported_float_dtypes(xp):

https://data-apis.org/array-api/latest/API_specification/data_types.html
"""
dtypes_dict = xp.__array_namespace_info__().dtypes(
kind="real floating", device=device
)
valid_float_dtypes = []
for dtype_key in ("float64", "float32"):
if dtype_key in dtypes_dict:
valid_float_dtypes.append(dtypes_dict[dtype_key])

if hasattr(xp, "float16"):
return (xp.float64, xp.float32, xp.float16)
else:
return (xp.float64, xp.float32)
valid_float_dtypes.append(xp.float16)

return tuple(valid_float_dtypes)


def ensure_common_namespace_device(reference, *arrays):
Expand Down
18 changes: 18 additions & 0 deletions sklearn/utils/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
get_namespace_and_device,
indexing_dtype,
np_compat,
supported_float_dtypes,
yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
Expand Down Expand Up @@ -777,3 +778,20 @@ def test_logsumexp_like_scipy_logsumexp(array_namespace, device_, dtype_name, ax
res_xp_2 = _logsumexp(array_xp_2, axis=axis)
res_xp_2 = _convert_to_numpy(res_xp_2, xp)
assert_allclose(res_np_2, res_xp_2, rtol=rtol)


@pytest.mark.parametrize(
("namespace", "device_", "expected_types"),
[
("numpy", None, ("float64", "float32", "float16")),
("array_api_strict", None, ("float64", "float32")),
("torch", "cpu", ("float64", "float32", "float16")),
("torch", "cuda", ("float64", "float32", "float16")),
("torch", "mps", ("float32", "float16")),
],
)
def test_supported_float_types(namespace, device_, expected_types):
xp = _array_api_for_tests(namespace, device_)
float_types = supported_float_dtypes(xp, device=device_)
expected = tuple(getattr(xp, dtype_name) for dtype_name in expected_types)
assert float_types == expected