Skip to content

Commit 202b532

Browse files
liamgejnothman
authored andcommitted
MAINT Remove redundancy in #9552 (#9573)
1 parent e41c4d5 commit 202b532

File tree

3 files changed

+26
-8
lines changed

3 files changed

+26
-8
lines changed

sklearn/preprocessing/tests/test_data.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from sklearn.utils import gen_batches
1515

16+
from sklearn.utils.testing import assert_raise_message
1617
from sklearn.utils.testing import assert_almost_equal
1718
from sklearn.utils.testing import clean_warning_registry
1819
from sklearn.utils.testing import assert_array_almost_equal
@@ -932,6 +933,10 @@ def test_quantile_transform_check_error():
932933
assert_raises_regex(ValueError, "'output_distribution' has to be either"
933934
" 'normal' or 'uniform'. Got 'rnd' instead.",
934935
transformer.inverse_transform, X_tran)
936+
# check that an error is raised if input is scalar
937+
assert_raise_message(ValueError,
938+
'Expected 2D array, got scalar array instead',
939+
transformer.transform, 10)
935940

936941

937942
def test_quantile_transform_sparse_ignore_zeros():
@@ -1157,14 +1162,16 @@ def test_quantile_transform_bounds():
11571162
X = np.random.random((1000, 1))
11581163
transformer = QuantileTransformer()
11591164
transformer.fit(X)
1160-
assert_equal(transformer.transform(-10), transformer.transform(np.min(X)))
1161-
assert_equal(transformer.transform(10), transformer.transform(np.max(X)))
1162-
assert_equal(transformer.inverse_transform(-10),
1165+
assert_equal(transformer.transform([[-10]]),
1166+
transformer.transform([[np.min(X)]]))
1167+
assert_equal(transformer.transform([[10]]),
1168+
transformer.transform([[np.max(X)]]))
1169+
assert_equal(transformer.inverse_transform([[-10]]),
11631170
transformer.inverse_transform(
1164-
np.min(transformer.references_)))
1165-
assert_equal(transformer.inverse_transform(10),
1171+
[[np.min(transformer.references_)]]))
1172+
assert_equal(transformer.inverse_transform([[10]]),
11661173
transformer.inverse_transform(
1167-
np.max(transformer.references_)))
1174+
[[np.max(transformer.references_)]]))
11681175

11691176

11701177
def test_quantile_transform_and_inverse():

sklearn/utils/tests/test_validation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,9 +142,13 @@ def test_check_array():
142142
# ensure_2d=False
143143
X_array = check_array([0, 1, 2], ensure_2d=False)
144144
assert_equal(X_array.ndim, 1)
145-
# ensure_2d=True
145+
# ensure_2d=True with 1d array
146146
assert_raise_message(ValueError, 'Expected 2D array, got 1D array instead',
147147
check_array, [0, 1, 2], ensure_2d=True)
148+
# ensure_2d=True with scalar array
149+
assert_raise_message(ValueError,
150+
'Expected 2D array, got scalar array instead',
151+
check_array, 10, ensure_2d=True)
148152
# don't allow ndim > 3
149153
X_ndim = np.arange(8).reshape(2, 2, 2)
150154
assert_raises(ValueError, check_array, X_ndim)

sklearn/utils/validation.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,13 +459,20 @@ def check_array(array, accept_sparse=False, dtype="numeric", order=None,
459459
_ensure_no_complex_data(array)
460460

461461
if ensure_2d:
462+
# If input is scalar raise error
463+
if array.ndim == 0:
464+
raise ValueError(
465+
"Expected 2D array, got scalar array instead:\narray={}.\n"
466+
"Reshape your data either using array.reshape(-1, 1) if "
467+
"your data has a single feature or array.reshape(1, -1) "
468+
"if it contains a single sample.".format(array))
469+
# If input is 1D raise error
462470
if array.ndim == 1:
463471
raise ValueError(
464472
"Expected 2D array, got 1D array instead:\narray={}.\n"
465473
"Reshape your data either using array.reshape(-1, 1) if "
466474
"your data has a single feature or array.reshape(1, -1) "
467475
"if it contains a single sample.".format(array))
468-
array = np.atleast_2d(array)
469476
# To ensure that array flags are maintained
470477
array = np.array(array, dtype=dtype, order=order, copy=copy)
471478

0 commit comments

Comments
 (0)