
Commit 48cba5a

FEA Make standard scaler compatible to Array API (#27113)

Authored by: AlexanderFabisch, EdAbati, ogrisel, Charles Hill, and OmarManzoor

Co-authored-by: Edoardo Abati <29585319+EdAbati@users.noreply.github.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
Co-authored-by: Charles Hill <charles.hill@etegent.com>
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
Co-authored-by: Loïc Estève <loic.esteve@ymail.com>

1 parent 872be3c · commit 48cba5a

File tree

9 files changed: +318 −54 lines

doc/modules/array_api.rst

Lines changed: 3 additions & 1 deletion

@@ -123,6 +123,7 @@ Estimators
 - :class:`preprocessing.MinMaxScaler`
 - :class:`preprocessing.Normalizer`
 - :class:`preprocessing.PolynomialFeatures`
+- :class:`preprocessing.StandardScaler` (see :ref:`device_support_for_float64`)
 - :class:`mixture.GaussianMixture` (with `init_params="random"` or
   `init_params="random_from_data"` and `warm_start=False`)

@@ -329,7 +330,8 @@ Note on device support for ``float64``

 Certain operations within scikit-learn will automatically perform operations
 on floating-point values with `float64` precision to prevent overflows and ensure
-correctness (e.g., :func:`metrics.pairwise.euclidean_distances`). However,
+correctness (e.g., :func:`metrics.pairwise.euclidean_distances`,
+:class:`preprocessing.StandardScaler`). However,
 certain combinations of array namespaces and devices, such as `PyTorch on MPS`
 (see :ref:`mps_support`) do not support the `float64` data type. In these cases,
 scikit-learn will revert to using the `float32` data type instead. This can result in
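In practice, enabling Array API dispatch lets the scaler operate directly on non-NumPy inputs. A minimal usage sketch (assuming PyTorch and the array-api-compat package are installed; the tensor values are illustrative):

import torch
from sklearn import config_context
from sklearn.preprocessing import StandardScaler

X = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

with config_context(array_api_dispatch=True):
    X_scaled = StandardScaler().fit_transform(X)

print(type(X_scaled))  # remains a torch.Tensor; no round-trip through NumPy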
Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+- :class:`sklearn.preprocessing.StandardScaler` now supports Array API compliant inputs.
+  :pr:`27113` by :user:`Alexander Fabisch <AlexanderFabisch>`, :user:`Edoardo Abati <EdAbati>`,
+  :user:`Olivier Grisel <ogrisel>` and :user:`Charles Hill <charlesjhill>`.

sklearn/preprocessing/_data.py

Lines changed: 27 additions & 18 deletions

@@ -20,10 +20,13 @@
 from sklearn.utils import _array_api, check_array, metadata_routing, resample
 from sklearn.utils._array_api import (
     _find_matching_floating_dtype,
+    _max_precision_float_dtype,
     _modify_in_place_if_numpy,
     device,
     get_namespace,
     get_namespace_and_device,
+    size,
+    supported_float_dtypes,
 )
 from sklearn.utils._param_validation import (
     Interval,
@@ -86,7 +89,9 @@ def _is_constant_feature(var, mean, n_samples):
     recommendations", by Chan, Golub, and LeVeque.
     """
     # In scikit-learn, variance is always computed using float64 accumulators.
-    eps = np.finfo(np.float64).eps
+    xp, _, device_ = get_namespace_and_device(var, mean)
+    max_float_dtype = _max_precision_float_dtype(xp=xp, device=device_)
+    eps = xp.finfo(max_float_dtype).eps

     upper_bound = n_samples * eps * var + (n_samples * mean * eps) ** 2
     return var <= upper_bound
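The point of routing `eps` through `_max_precision_float_dtype` is that on devices without `float64` (e.g. PyTorch on MPS) the accumulators are `float32`, so the constant-feature bound must use the coarser `float32` machine epsilon. A rough sketch of the idea behind the helper (a simplified stand-in, not the actual implementation):

def max_precision_float_dtype_sketch(xp, device=None):
    # Pick float64 when the namespace/device pair supports it,
    # otherwise fall back to float32.
    dtypes = xp.__array_namespace_info__().dtypes(
        device=device, kind="real floating"
    )
    return dtypes.get("float64", xp.float32)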
@@ -952,12 +957,13 @@ def partial_fit(self, X, y=None, sample_weight=None):
         self : object
             Fitted scaler.
         """
+        xp, _, X_device = get_namespace_and_device(X)
         first_call = not hasattr(self, "n_samples_seen_")
         X = validate_data(
             self,
             X,
             accept_sparse=("csr", "csc"),
-            dtype=FLOAT_DTYPES,
+            dtype=supported_float_dtypes(xp, X_device),
             ensure_all_finite="allow-nan",
             reset=first_call,
         )
@@ -971,14 +977,14 @@ def partial_fit(self, X, y=None, sample_weight=None):
         # See incr_mean_variance_axis and _incremental_mean_variance_axis

         # if n_samples_seen_ is an integer (i.e. no missing values), we need to
-        # transform it to a NumPy array of shape (n_features,) required by
+        # transform it to an array of shape (n_features,) required by
         # incr_mean_variance_axis and _incremental_variance_axis
-        dtype = np.int64 if sample_weight is None else X.dtype
-        if not hasattr(self, "n_samples_seen_"):
-            self.n_samples_seen_ = np.zeros(n_features, dtype=dtype)
-        elif np.size(self.n_samples_seen_) == 1:
-            self.n_samples_seen_ = np.repeat(self.n_samples_seen_, X.shape[1])
-        self.n_samples_seen_ = self.n_samples_seen_.astype(dtype, copy=False)
+        dtype = xp.int64 if sample_weight is None else X.dtype
+        if first_call:
+            self.n_samples_seen_ = xp.zeros(n_features, dtype=dtype, device=X_device)
+        elif size(self.n_samples_seen_) == 1:
+            self.n_samples_seen_ = xp.repeat(self.n_samples_seen_, X.shape[1])
+        self.n_samples_seen_ = xp.astype(self.n_samples_seen_, dtype, copy=False)

         if sparse.issparse(X):
             if self.with_mean:
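These replacements are the standard Array API translations: casts go through the namespace-level `xp.astype(array, dtype)` function rather than the NumPy-only `.astype` method, and array creation takes an explicit `device` argument so new arrays land on the same device as the input. A minimal illustration (assuming NumPy >= 2.0, which implements these Array API entry points):

import numpy as xp  # any Array API namespace behaves the same way

counts = xp.zeros(3, dtype=xp.float64, device="cpu")  # explicit device placement
counts = xp.astype(counts, xp.int64, copy=False)      # functional cast, no .astype method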
@@ -1036,7 +1042,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
             if not self.with_mean and not self.with_std:
                 self.mean_ = None
                 self.var_ = None
-                self.n_samples_seen_ += X.shape[0] - np.isnan(X).sum(axis=0)
+                self.n_samples_seen_ += X.shape[0] - xp.isnan(X).sum(axis=0)

         else:
             self.mean_, self.var_, self.n_samples_seen_ = _incremental_mean_and_var(

@@ -1050,7 +1056,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
         # for backward-compatibility, reduce n_samples_seen_ to an integer
         # if the number of samples is the same for each feature (i.e. no
         # missing values)
-        if np.ptp(self.n_samples_seen_) == 0:
+        if xp.max(self.n_samples_seen_) == xp.min(self.n_samples_seen_):
            self.n_samples_seen_ = self.n_samples_seen_[0]

         if self.with_std:
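`np.ptp` (peak-to-peak, i.e. max minus min) is not part of the Array API, so the "all per-feature counts are equal" test is spelled out with `xp.max`/`xp.min`. The two forms are equivalent here:

import numpy as np

seen = np.array([150, 150, 150])
assert np.ptp(seen) == 0             # NumPy-only spelling
assert np.max(seen) == np.min(seen)  # portable Array API spelling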
@@ -1060,7 +1066,7 @@ def partial_fit(self, X, y=None, sample_weight=None):
                 self.var_, self.mean_, self.n_samples_seen_
             )
             self.scale_ = _handle_zeros_in_scale(
-                np.sqrt(self.var_), copy=False, constant_mask=constant_mask
+                xp.sqrt(self.var_), copy=False, constant_mask=constant_mask
             )
         else:
             self.scale_ = None
@@ -1082,6 +1088,7 @@ def transform(self, X, copy=None):
         X_tr : {ndarray, sparse matrix} of shape (n_samples, n_features)
             Transformed array.
         """
+        xp, _, X_device = get_namespace_and_device(X)
         check_is_fitted(self)

         copy = copy if copy is not None else self.copy

@@ -1091,7 +1098,7 @@ def transform(self, X, copy=None):
             reset=False,
             accept_sparse="csr",
             copy=copy,
-            dtype=FLOAT_DTYPES,
+            dtype=supported_float_dtypes(xp, X_device),
             force_writeable=True,
             ensure_all_finite="allow-nan",
         )
@@ -1106,9 +1113,9 @@ def transform(self, X, copy=None):
                 inplace_column_scale(X, 1 / self.scale_)
         else:
             if self.with_mean:
-                X -= self.mean_
+                X -= xp.astype(self.mean_, X.dtype)
             if self.with_std:
-                X /= self.scale_
+                X /= xp.astype(self.scale_, X.dtype)
         return X

     def inverse_transform(self, X, copy=None):
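The explicit casts are needed because `mean_` and `scale_` may be stored at a higher precision than `X` (see the `float64` note in the docs diff above), and strict Array API namespaces reject in-place operations that would promote the left-hand dtype. A toy reproduction of the failure mode (assuming the array-api-strict package is installed):

import array_api_strict as xp

X = xp.asarray([[1.0, 2.0]], dtype=xp.float32)
mean = xp.asarray([0.5, 0.5], dtype=xp.float64)

# X -= mean  # would raise: an in-place op may not promote float32 to float64
X -= xp.astype(mean, X.dtype)  # cast first, then subtract in place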
@@ -1127,14 +1134,15 @@ def inverse_transform(self, X, copy=None):
         X_original : {ndarray, sparse matrix} of shape (n_samples, n_features)
             Transformed array.
         """
+        xp, _, X_device = get_namespace_and_device(X)
         check_is_fitted(self)

         copy = copy if copy is not None else self.copy
         X = check_array(
             X,
             accept_sparse="csr",
             copy=copy,
-            dtype=FLOAT_DTYPES,
+            dtype=supported_float_dtypes(xp, X_device),
             force_writeable=True,
             ensure_all_finite="allow-nan",
         )

@@ -1149,16 +1157,17 @@ def inverse_transform(self, X, copy=None):
             inplace_column_scale(X, self.scale_)
         else:
             if self.with_std:
-                X *= self.scale_
+                X *= xp.astype(self.scale_, X.dtype)
             if self.with_mean:
-                X += self.mean_
+                X += xp.astype(self.mean_, X.dtype)
         return X

     def __sklearn_tags__(self):
         tags = super().__sklearn_tags__()
         tags.input_tags.allow_nan = True
         tags.input_tags.sparse = not self.with_mean
         tags.transformer_tags.preserves_dtype = ["float64", "float32"]
+        tags.array_api_support = True
         return tags
sklearn/preprocessing/tests/test_data.py

Lines changed: 100 additions & 7 deletions

@@ -43,7 +43,6 @@
     _get_namespace_device_dtype_ids,
     yield_namespace_device_dtype_combinations,
 )
-from sklearn.utils._test_common.instance_generator import _get_check_estimator_ids
 from sklearn.utils._testing import (
     _array_api_for_tests,
     _convert_container,

@@ -56,6 +55,7 @@
     skip_if_32bit,
 )
 from sklearn.utils.estimator_checks import (
+    _get_check_estimator_ids,
     check_array_api_input_and_values,
 )
 from sklearn.utils.fixes import (
@@ -117,10 +117,13 @@ def test_raises_value_error_if_sample_weights_greater_than_1d():
         scaler.fit(X, y, sample_weight=sample_weight_notOK)


-@pytest.mark.parametrize(
-    ["Xw", "X", "sample_weight"],
-    [
-        ([[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [1, 2, 3], [4, 5, 6]], [2.0, 1.0]),
+def _yield_xw_x_sampleweight():
+    yield from (
+        (
+            [[1, 2, 3], [4, 5, 6]],
+            [[1, 2, 3], [1, 2, 3], [4, 5, 6]],
+            [2.0, 1.0],
+        ),
         (
             [[1, 0, 1], [0, 0, 1]],
             [[1, 0, 1], [0, 0, 1], [0, 0, 1], [0, 0, 1]],

@@ -136,8 +139,10 @@ def test_raises_value_error_if_sample_weights_greater_than_1d():
             ],
             np.array([1, 3]),
         ),
-    ],
-)
+    )
+
+
+@pytest.mark.parametrize(["Xw", "X", "sample_weight"], _yield_xw_x_sampleweight())
 @pytest.mark.parametrize("array_constructor", ["array", "sparse_csr", "sparse_csc"])
 def test_standard_scaler_sample_weight(Xw, X, sample_weight, array_constructor):
     with_mean = not array_constructor.startswith("sparse")
@@ -161,6 +166,68 @@ def test_standard_scaler_sample_weight(Xw, X, sample_weight, array_constructor):
     assert_almost_equal(scaler.transform(X_test), scaler_w.transform(X_test))


+@pytest.mark.parametrize(["Xw", "X", "sample_weight"], _yield_xw_x_sampleweight())
+@pytest.mark.parametrize(
+    "namespace, dev, dtype",
+    yield_namespace_device_dtype_combinations(),
+    ids=_get_namespace_device_dtype_ids,
+)
+def test_standard_scaler_sample_weight_array_api(
+    Xw, X, sample_weight, namespace, dev, dtype
+):
+    # N.B. The sample statistics for Xw w/ sample_weight should match
+    # the statistics of X w/ uniform sample_weight.
+    xp = _array_api_for_tests(namespace, dev)
+
+    X = np.array(X).astype(dtype, copy=False)
+    y = np.ones(X.shape[0]).astype(dtype, copy=False)
+    Xw = np.array(Xw).astype(dtype, copy=False)
+    yw = np.ones(Xw.shape[0]).astype(dtype, copy=False)
+    X_test = np.array([[1.5, 2.5, 3.5], [3.5, 4.5, 5.5]]).astype(dtype, copy=False)
+
+    scaler = StandardScaler()
+    scaler.fit(X, y)
+
+    scaler_w = StandardScaler()
+    scaler_w.fit(Xw, yw, sample_weight=sample_weight)
+
+    # Test array-api support and correctness.
+    X_xp = xp.asarray(X, device=dev)
+    y_xp = xp.asarray(y, device=dev)
+    Xw_xp = xp.asarray(Xw, device=dev)
+    yw_xp = xp.asarray(yw, device=dev)
+    X_test_xp = xp.asarray(X_test, device=dev)
+    sample_weight_xp = xp.asarray(sample_weight, device=dev)
+
+    scaler_w_xp = StandardScaler()
+    with config_context(array_api_dispatch=True):
+        scaler_w_xp.fit(Xw_xp, yw_xp, sample_weight=sample_weight_xp)
+        w_mean = _convert_to_numpy(scaler_w_xp.mean_, xp=xp)
+        w_var = _convert_to_numpy(scaler_w_xp.var_, xp=xp)
+
+    assert_allclose(scaler_w.mean_, w_mean)
+    assert_allclose(scaler_w.var_, w_var)
+
+    # unweighted, but with repeated samples
+    scaler_xp = StandardScaler()
+    with config_context(array_api_dispatch=True):
+        scaler_xp.fit(X_xp, y_xp)
+        uw_mean = _convert_to_numpy(scaler_xp.mean_, xp=xp)
+        uw_var = _convert_to_numpy(scaler_xp.var_, xp=xp)
+
+    assert_allclose(scaler.mean_, uw_mean)
+    assert_allclose(scaler.var_, uw_var)
+
+    # Check that both array-api outputs match.
+    assert_allclose(uw_mean, w_mean)
+    assert_allclose(uw_var, w_var)
+    with config_context(array_api_dispatch=True):
+        assert_allclose(
+            _convert_to_numpy(scaler_xp.transform(X_test_xp), xp=xp),
+            _convert_to_numpy(scaler_w_xp.transform(X_test_xp), xp=xp),
+        )
+
+
 def test_standard_scaler_1d():
     # Test scaling of dataset along single axis
     for X in [X_1row, X_1col, X_list_1row, X_list_1row]:
@@ -726,6 +793,32 @@ def test_preprocessing_array_api_compliance(
     check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)


+@pytest.mark.parametrize(
+    "array_namespace, device, dtype_name",
+    yield_namespace_device_dtype_combinations(),
+    ids=_get_namespace_device_dtype_ids,
+)
+@pytest.mark.parametrize(
+    "check",
+    [check_array_api_input_and_values],
+    ids=_get_check_estimator_ids,
+)
+@pytest.mark.parametrize("sample_weight", [True, None])
+def test_standard_scaler_array_api_compliance(
+    check, sample_weight, array_namespace, device, dtype_name
+):
+    estimator = StandardScaler()
+    name = estimator.__class__.__name__
+    check(
+        name,
+        estimator,
+        array_namespace,
+        device=device,
+        dtype_name=dtype_name,
+        check_sample_weight=sample_weight,
+    )
+
+
 def test_min_max_scaler_iris():
     X = iris.data
     scaler = MinMaxScaler()

sklearn/utils/_array_api.py

Lines changed: 38 additions & 1 deletion

@@ -246,10 +246,34 @@ def _union1d(a, b, xp):
 def supported_float_dtypes(xp, device=None):
     """Supported floating point types for the namespace.

-    Note: float16 is not officially part of the Array API spec at the
+    Parameters
+    ----------
+    xp : module
+        Array namespace to inspect.
+
+    device : str or device instance from xp, default=None
+        Device to use for dtype selection. If ``None``, then a default device
+        is assumed.
+
+    Returns
+    -------
+    supported_dtypes : tuple
+        Tuple of real floating data types supported by the provided array namespace,
+        ordered from the highest precision to lowest.
+
+    See Also
+    --------
+    max_precision_float_dtype : Maximum float dtype for a namespace/device pair.
+
+    Notes
+    -----
+    `float16` is not officially part of the Array API spec at the
     time of writing but scikit-learn estimators and functions can choose
     to accept it when xp.float16 is defined.

+    Additionally, some devices available within a namespace may not support
+    all floating-point types that the namespace provides.
+
     https://data-apis.org/array-api/latest/API_specification/data_types.html
     """
     dtypes_dict = xp.__array_namespace_info__().dtypes(
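The practical effect is that the returned tuple shrinks on devices lacking `float64`. A quick way to see what a namespace reports (assuming NumPy >= 2.0, which ships `__array_namespace_info__`; on PyTorch/MPS the float64 entry would be absent):

import numpy as np

dtypes = np.__array_namespace_info__().dtypes(kind="real floating")
print(sorted(dtypes))  # ['float32', 'float64'] for NumPy on CPU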
@@ -748,6 +772,19 @@ def _nanmean(X, axis=None, xp=None):
     return total / count


+def _nansum(X, axis=None, xp=None, keepdims=False, dtype=None):
+    # TODO: refactor once nan-aware reductions are standardized:
+    # https://github.com/data-apis/array-api/issues/621
+    xp, _, X_device = get_namespace_and_device(X, xp=xp)
+
+    if _is_numpy_namespace(xp):
+        return xp.asarray(numpy.nansum(X, axis=axis, keepdims=keepdims, dtype=dtype))
+
+    mask = xp.isnan(X)
+    masked_arr = xp.where(mask, xp.asarray(0, device=X_device, dtype=X.dtype), X)
+    return xp.sum(masked_arr, axis=axis, keepdims=keepdims, dtype=dtype)
+
+
 def _asarray_with_order(
     array, dtype=None, order=None, copy=None, *, xp=None, device=None
 ):
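The portable branch of `_nansum` uses a mask-and-replace pattern that any Array API namespace supports: zero out the NaNs, then run an ordinary sum. For example (plain NumPy standing in as the namespace for illustration):

import numpy as xp

X = xp.asarray([[1.0, xp.nan], [3.0, 4.0]])
masked = xp.where(xp.isnan(X), xp.asarray(0.0, dtype=X.dtype), X)
print(xp.sum(masked, axis=0))  # [4. 4.], NaNs treated as zero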
