Skip to content

Commit 3cd8062

Browse files
authored
MNT Prune unused argument in _array_api_for_tests util (scikit-learn#27941)
1 parent 8e10cd7 commit 3cd8062

File tree

7 files changed

+8
-8
lines changed

7 files changed

+8
-8
lines changed

sklearn/decomposition/tests/test_pca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,7 @@ def test_variance_correctness(copy):
818818

819819

820820
def check_array_api_get_precision(name, estimator, array_namespace, device, dtype):
821-
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
821+
xp = _array_api_for_tests(array_namespace, device)
822822
iris_np = iris.data.astype(dtype)
823823
iris_xp = xp.asarray(iris_np, device=device)
824824

sklearn/metrics/tests/test_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1735,7 +1735,7 @@ def test_metrics_pos_label_error_str(metric, y_pred_threshold, dtype_y_str):
17351735
def check_array_api_metric(
17361736
metric, array_namespace, device, dtype, y_true_np, y_pred_np, sample_weight=None
17371737
):
1738-
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
1738+
xp = _array_api_for_tests(array_namespace, device)
17391739
y_true_xp = xp.asarray(y_true_np, device=device)
17401740
y_pred_xp = xp.asarray(y_pred_np, device=device)
17411741

sklearn/model_selection/tests/test_split.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1279,7 +1279,7 @@ def test_train_test_split_default_test_size(train_size, exp_train, exp_test):
12791279
),
12801280
)
12811281
def test_array_api_train_test_split(shuffle, stratify, array_namespace, device, dtype):
1282-
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
1282+
xp = _array_api_for_tests(array_namespace, device)
12831283

12841284
X = np.arange(100).reshape((10, 10))
12851285
y = np.arange(10)

sklearn/utils/_testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,7 +1030,7 @@ def fit_transform(self, X, y=None):
10301030
return self.fit(X, y).transform(X, y)
10311031

10321032

1033-
def _array_api_for_tests(array_namespace, device, dtype):
1033+
def _array_api_for_tests(array_namespace, device):
10341034
try:
10351035
if array_namespace == "numpy.array_api":
10361036
# FIXME: once it is not experimental anymore
@@ -1079,4 +1079,4 @@ def _array_api_for_tests(array_namespace, device, dtype):
10791079

10801080
if cupy.cuda.runtime.getDeviceCount() == 0:
10811081
raise SkipTest("CuPy test requires cuda, which is not available")
1082-
return xp, device, dtype
1082+
return xp

sklearn/utils/estimator_checks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -875,7 +875,7 @@ def check_array_api_input(
875875
When check_values is True, it also checks that calling the estimator on the
876876
array_api Array gives the same results as ndarrays.
877877
"""
878-
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
878+
xp = _array_api_for_tests(array_namespace, device)
879879

880880
X, y = make_classification(random_state=42)
881881
X = X.astype(dtype, copy=False)

sklearn/utils/tests/test_array_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def test_asarray_with_order_ignored():
145145
def test_weighted_sum(
146146
array_namespace, device, dtype, sample_weight, normalize, expected
147147
):
148-
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
148+
xp = _array_api_for_tests(array_namespace, device)
149149
sample_score = numpy.asarray([1, 2, 3, 4], dtype=dtype)
150150
sample_score = xp.asarray(sample_score, device=device)
151151
if sample_weight is not None:

sklearn/utils/tests/test_multiclass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -383,7 +383,7 @@ def test_is_multilabel():
383383
yield_namespace_device_dtype_combinations(),
384384
)
385385
def test_is_multilabel_array_api_compliance(array_namespace, device, dtype):
386-
xp, device, dtype = _array_api_for_tests(array_namespace, device, dtype)
386+
xp = _array_api_for_tests(array_namespace, device)
387387

388388
for group, group_examples in ARRAY_API_EXAMPLES.items():
389389
dense_exp = group == "multilabel-indicator"

0 commit comments

Comments
 (0)