From c01385c052a010f138707ad41e593414cfe76133 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 25 Jul 2019 18:20:00 +0200 Subject: [PATCH 01/45] EHN add support for scalar, slice and mask in safe_indexing axis=0 --- doc/whats_new/v0.22.rst | 4 ++ sklearn/utils/__init__.py | 21 +++++++++-- sklearn/utils/tests/test_utils.py | 62 ++++++++++++++++++++++++++----- 3 files changed, 73 insertions(+), 14 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index bcef08ff1881b..d1d838ec99dc5 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -206,6 +206,10 @@ Changelog NumPy array, SciPy sparse matrix, and Pandas DataFrame. :pr:`14035` by `Guillaume Lemaitre `. +- |Enhancement| :func:`utils.safe_indexing` becomes more consistent and support + scalar, slice, mask indexing for both `axis=0` and `axis=1`. + :pr:`xx` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.neural_network` ............................. diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index efcaf6865faa5..2b2b33b0186d4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1,6 +1,7 @@ """ The :mod:`sklearn.utils` module includes various utilities. """ +from collections.abc import Iterable from collections.abc import Sequence from contextlib import contextmanager from itertools import islice @@ -188,6 +189,10 @@ def safe_indexing(X, indices, axis=0): Data from which to sample rows, items or columns. indices : array-like - When ``axis=0``, indices need to be an array of integer. + - container: lists, slices, boolean masks: output is 2D. + Supported data types for containers: + - integer or boolean (positional): supported for arrays, sparse + and dataframes - When ``axis=1``, indices can be one of: - scalar: output is 1D, unless `X` is sparse. Supported data types for scalars: @@ -247,10 +252,12 @@ def _safe_indexing_row(X, indices): CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are not supported. 
""" - if hasattr(X, "iloc"): - # Work-around for indexing with read-only indices in pandas + if not isinstance(indices, slice): indices = np.asarray(indices) - indices = indices if indices.flags.writeable else indices.copy() + if hasattr(X, "iloc"): + if not isinstance(indices, slice): + # Work-around for indexing with read-only indices in pandas + indices = indices if indices.flags.writeable else indices.copy() # Pandas Dataframes and Series try: return X.iloc[indices] @@ -268,7 +275,13 @@ def _safe_indexing_row(X, indices): else: return X[indices] else: - return [X[idx] for idx in indices] + # In the case of a slice or a scalar + if not isinstance(indices, Iterable) or indices.ndim == 0: + return X[indices] + else: + if np.issubdtype(indices.dtype, np.bool_): + indices = np.flatnonzero(indices) + return [X[idx] for idx in indices] def _check_key_type(key, superclass): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index a39e8160047a5..1601002449924 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -10,6 +10,7 @@ from sklearn.utils.testing import (assert_raises, assert_array_equal, + assert_allclose_dense_sparse, assert_raises_regex, assert_warns_message, assert_no_warnings) from sklearn.utils import check_random_state @@ -219,14 +220,51 @@ def test_check_key_type(key, clazz, is_expected_type): assert _check_key_type(key, clazz) is is_expected_type -@pytest.mark.parametrize("asarray", [True, False], ids=["array-like", "array"]) -def test_safe_indexing_axis_0(asarray): +@pytest.mark.parametrize( + "idx", + [[0, 2], [True, False, True], # array-like + np.array([0, 2]), np.array([True, False, True])], # numpy array + ids=['list-indices', 'list-mask', 'array-indices', 'array-mask'] +) +@pytest.mark.parametrize( + "array_type", [None, np.asarray, sp.csr_matrix], + ids=["list", "array", "sparse"] +) +def test_safe_indexing_axis_0_container(idx, array_type): X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - inds = np.array([1, 2]) if asarray else [1, 2] - X_inds = safe_indexing(X, inds) - X_arrays = safe_indexing(np.array(X), inds) - assert_array_equal(np.array(X_inds), X_arrays) - assert_array_equal(np.array(X_inds), np.array(X)[inds]) + X = array_type(X) if array_type is not None else X + X_subset = safe_indexing(X, idx, axis=0) + X_expect = [[1, 2, 3], [7, 8, 9]] + X_expect = array_type(X_expect) if array_type is not None else X_expect + assert_allclose_dense_sparse(X_subset, X_expect) + + +@pytest.mark.parametrize( + "array_type", [None, np.asarray, sp.csr_matrix], + ids=["list", "array", "sparse"] +) +def test_safe_indexing_axis_0_slice(array_type): + X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + X = array_type(X) if array_type is not None else X + idx = slice(0, 2) + X_subset = safe_indexing(X, idx, axis=0) + X_expect = [[1, 2, 3], [4, 5, 6]] + X_expect = array_type(X_expect) if array_type is not None else X_expect + assert_allclose_dense_sparse(X_subset, X_expect) + + +@pytest.mark.parametrize( + "array_type", [None, np.asarray, sp.csr_matrix], + ids=["list", "array", "sparse"] +) +def test_safe_indexing_axis_0_scalar(array_type): + X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + X = array_type(X) if array_type is not None else X + idx = 1 # scalar indexing + X_subset = safe_indexing(X, idx, axis=0) + X_expect = [4, 5, 6] + X_expect = array_type(X_expect) if array_type is not None else X_expect + assert_allclose_dense_sparse(X_subset, X_expect) @pytest.mark.parametrize("idx", [0, [0, 1]], ids=['scalar', 'list']) @@ -355,11 +393,15 @@ def 
test_safe_indexing_pandas_series(idx, asarray): assert_array_equal(safe_indexing(serie, idx).values, [0, 1]) -@pytest.mark.parametrize("asarray", [True, False], ids=["array-like", "array"]) -def test_safe_indexing_mock_pandas(asarray): +@pytest.mark.parametrize( + "inds", + [[1, 2], [False, True, True], + np.array([1, 2]), np.array([False, True, True]), + slice(1, None)] +) +def test_safe_indexing_mock_pandas(inds): X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) X_df = MockDataFrame(X) - inds = np.array([1, 2]) if asarray else [1, 2] X_df_indexed = safe_indexing(X_df, inds) X_indexed = safe_indexing(X_df, inds) assert_array_equal(np.array(X_df_indexed), X_indexed) From 0e5c03738557804fcccc9a4c1746a38714529139 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 25 Jul 2019 18:29:37 +0200 Subject: [PATCH 02/45] DOC --- doc/whats_new/v0.22.rst | 2 +- sklearn/utils/__init__.py | 17 ++++++----------- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index d1d838ec99dc5..716a500ebe085 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -208,7 +208,7 @@ Changelog - |Enhancement| :func:`utils.safe_indexing` becomes more consistent and support scalar, slice, mask indexing for both `axis=0` and `axis=1`. - :pr:`xx` by :user:`Guillaume Lemaitre `. + :pr:`14475` by :user:`Guillaume Lemaitre `. :mod:`sklearn.neural_network` ............................. diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 2b2b33b0186d4..b1bb933efd140 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -188,23 +188,18 @@ def safe_indexing(X, indices, axis=0): X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series Data from which to sample rows, items or columns. indices : array-like - - When ``axis=0``, indices need to be an array of integer. - - container: lists, slices, boolean masks: output is 2D. - Supported data types for containers: - - integer or boolean (positional): supported for arrays, sparse - and dataframes - - When ``axis=1``, indices can be one of: + - For both `axis=0` and `axis=1`, indices can be one of: - scalar: output is 1D, unless `X` is sparse. Supported data types for scalars: - integer: supported for arrays, sparse matrices and dataframes. - - string (key-based): only supported for dataframes. - container: lists, slices, boolean masks: output is 2D. Supported data types for containers: - - integer or boolean (positional): supported for - arrays, sparse matrices and dataframes - - string (key-based): only supported for dataframes. No keys - other than strings are allowed. + - integer or boolean (positional): supported for arrays, sparse + and dataframes + - If `X` is a dataframe and `axis=1`, indices support string data type + (key-based) as a scalar or a container. The output dimension will be + identical to the above case. axis : int, default=0 The axis along which `X` will be subsampled. ``axis=0`` will select rows while ``axis=1`` will select columns. 
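The behaviour these first patches converge on can be summarised with a short usage sketch. It is based on the docstring and the tests added above, and assumes a scikit-learn build with this patch series applied, with NumPy and pandas available; the outputs noted in the comments are illustrative.

import numpy as np
import pandas as pd
from sklearn.utils import safe_indexing

X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# axis=0: a scalar, an integer array-like, a boolean mask or a slice selects rows.
safe_indexing(X, 1, axis=0)                              # array([4, 5, 6]); a scalar gives a 1D result
safe_indexing(X, [0, 2], axis=0)                         # rows 0 and 2; containers give a 2D result
safe_indexing(X, np.array([True, False, True]), axis=0)  # boolean mask, same rows as above
safe_indexing(X, slice(0, 2), axis=0)                    # first two rows

# axis=1: the same kinds of indices select columns; string keys are only
# valid when X is a pandas DataFrame.
df = pd.DataFrame(X, columns=['col_0', 'col_1', 'col_2'])
safe_indexing(df, 'col_1', axis=1)                       # single column as a 1D Series
safe_indexing(df, ['col_1', 'col_2'], axis=1)            # two columns as a DataFrame

# indices=None returns X untouched (behaviour fixed in the next patch).
safe_indexing(X, None, axis=0)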
From f5e08c4272fe1633f490ec15ff5381c5f5b3caef Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 25 Jul 2019 22:20:08 +0200 Subject: [PATCH 03/45] FIX behaviour when passing None --- sklearn/utils/__init__.py | 4 +++- sklearn/utils/tests/test_utils.py | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index b1bb933efd140..e0f5be6dc35c0 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -247,7 +247,9 @@ def _safe_indexing_row(X, indices): CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are not supported. """ - if not isinstance(indices, slice): + if indices is None: + return X + elif not isinstance(indices, slice): indices = np.asarray(indices) if hasattr(X, "iloc"): if not isinstance(indices, slice): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 1601002449924..737a20221b346 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -267,6 +267,17 @@ def test_safe_indexing_axis_0_scalar(array_type): assert_allclose_dense_sparse(X_subset, X_expect) +@pytest.mark.parametrize( + "array_type", [None, np.asarray, sp.csr_matrix], + ids=["list", "array", "sparse"] +) +def test_safe_indexing_axis_0_None(array_type): + X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + X = array_type(X) if array_type is not None else X + X_subset = safe_indexing(X, None, axis=0) + assert_allclose_dense_sparse(X_subset, X) + + @pytest.mark.parametrize("idx", [0, [0, 1]], ids=['scalar', 'list']) @pytest.mark.parametrize("asarray", [True, False], ids=["array-like", "array"]) def test_safe_indexing_axis_1_sparse(idx, asarray): @@ -395,9 +406,11 @@ def test_safe_indexing_pandas_series(idx, asarray): @pytest.mark.parametrize( "inds", - [[1, 2], [False, True, True], + [None, + [1, 2], [False, True, True], np.array([1, 2]), np.array([False, True, True]), - slice(1, None)] + slice(1, None)], + ids=['None', 'list-int', 'list-mask', 'array-int', 'array-mask', 'slice'] ) def test_safe_indexing_mock_pandas(inds): X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) From bb4db91755cf724a0e856ef9dab8bcccdecca686 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 25 Jul 2019 22:21:53 +0200 Subject: [PATCH 04/45] PEP8 --- sklearn/utils/tests/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 737a20221b346..468f7dae4e2e8 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -410,7 +410,7 @@ def test_safe_indexing_pandas_series(idx, asarray): [1, 2], [False, True, True], np.array([1, 2]), np.array([False, True, True]), slice(1, None)], - ids=['None', 'list-int', 'list-mask', 'array-int', 'array-mask', 'slice'] + ids=['None', 'list-int', 'list-mask', 'array-int', 'array-mask', 'slice'] ) def test_safe_indexing_mock_pandas(inds): X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) From 8cd74db2fba750c17c6b4dfe066e411dd30c43b6 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Jul 2019 12:05:25 +0200 Subject: [PATCH 05/45] address thomas comments --- sklearn/utils/__init__.py | 2 +- sklearn/utils/tests/test_utils.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index b1bb933efd140..e06cf0dd2838b 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -196,7 +196,7 @@ def 
safe_indexing(X, indices, axis=0): - container: lists, slices, boolean masks: output is 2D. Supported data types for containers: - integer or boolean (positional): supported for arrays, sparse - and dataframes + matrices and dataframes - If `X` is a dataframe and `axis=1`, indices support string data type (key-based) as a scalar or a container. The output dimension will be identical to the above case. diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 1601002449924..897a95542e33d 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -240,12 +240,12 @@ def test_safe_indexing_axis_0_container(idx, array_type): @pytest.mark.parametrize( - "array_type", [None, np.asarray, sp.csr_matrix], + "array_type", [list, np.asarray, sp.csr_matrix], ids=["list", "array", "sparse"] ) def test_safe_indexing_axis_0_slice(array_type): X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) if array_type is not None else X + X = array_type(X) idx = slice(0, 2) X_subset = safe_indexing(X, idx, axis=0) X_expect = [[1, 2, 3], [4, 5, 6]] @@ -254,12 +254,12 @@ def test_safe_indexing_axis_0_slice(array_type): @pytest.mark.parametrize( - "array_type", [None, np.asarray, sp.csr_matrix], + "array_type", [list, np.asarray, sp.csr_matrix], ids=["list", "array", "sparse"] ) def test_safe_indexing_axis_0_scalar(array_type): X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) if array_type is not None else X + X = array_type(X) idx = 1 # scalar indexing X_subset = safe_indexing(X, idx, axis=0) X_expect = [4, 5, 6] From d0f8d60a8654747f2f5600e3fae9c071c1c417aa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Jul 2019 18:31:25 +0200 Subject: [PATCH 06/45] FIX change boolean array-likes indexing in old NumPy version --- sklearn/compose/tests/test_column_transformer.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index ae7ef31d6c7f1..2ccfd6d6c2eae 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -16,6 +16,7 @@ from sklearn.base import BaseEstimator from sklearn.compose import ColumnTransformer, make_column_transformer from sklearn.exceptions import NotFittedError +from sklearn.preprocessing import FunctionTransformer from sklearn.preprocessing import StandardScaler, Normalizer, OneHotEncoder from sklearn.feature_extraction import DictVectorizer @@ -1108,3 +1109,14 @@ def test_column_transformer_reordered_column_names_remainder(explicit_colname): err_msg = 'Specifying the columns' with pytest.raises(ValueError, match=err_msg): tf.transform(X_array) + + +def test_column_transformer_mask_indexing(): + # Regression test for #xxxxx + # Boolean mask indexing with NumPy < 1.13 + X = np.transpose([[1, 2, 3], [4, 5, 6], [5, 6, 7], [8, 9, 10]]) + column_transformer = ColumnTransformer( + [('identity', FunctionTransformer(), [False, True, False, True])] + ) + X_trans = column_transformer.fit_transform(X) + assert X_trans.shape == (3, 2) From f95a228e5444d801b6ab09d30dcc655d347cc663 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 29 Jul 2019 19:12:10 +0200 Subject: [PATCH 07/45] change indexing --- sklearn/utils/__init__.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index efcaf6865faa5..8db41bb27986d 100644 --- a/sklearn/utils/__init__.py +++ 
b/sklearn/utils/__init__.py @@ -18,6 +18,7 @@ from . import _joblib from ..exceptions import DataConversionWarning from .deprecation import deprecated +from .fixes import np_version from .validation import (as_float_array, assert_all_finite, check_random_state, column_or_1d, check_array, @@ -225,6 +226,17 @@ def safe_indexing(X, indices, axis=0): ) +# FIXME: to be removed once NumPy 1.13 is the minimum version required +def _array_indexing(array, key, axis=0): + """Index an array consistently across NumPy version.""" + if np_version < (1, 13): + # check if we have an boolean array-likes to make the proper indexing + key_array = np.asarray(key) + if np.issubdtype(key_array.dtype, np.bool_): + key = key_array + return array[key] if axis == 0 else array[:, key] + + def _safe_indexing_row(X, indices): """Return items or rows from X using indices. @@ -266,7 +278,7 @@ def _safe_indexing_row(X, indices): # This is often substantially faster than X[indices] return X.take(indices, axis=0) else: - return X[indices] + return _array_indexing(X, indices, axis=0) else: return [X[idx] for idx in indices] @@ -356,7 +368,7 @@ def _safe_indexing_column(X, key): return X.iloc[:, key] else: # numpy arrays, sparse arrays - return X[:, key] + return _array_indexing(X, key, axis=1) def _get_column_indices(X, key): @@ -371,7 +383,7 @@ def _get_column_indices(X, key): or hasattr(key, 'dtype') and np.issubdtype(key.dtype, np.bool_)): # Convert key into positive indexes try: - idx = np.arange(n_columns)[key] + idx = safe_indexing(np.arange(n_columns), key) except IndexError as e: raise ValueError( 'all features must be in [0, %d]' % (n_columns - 1) From 1c8180390799d22bf42b5c1673caf0cb3dd71c79 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 30 Jul 2019 09:54:22 +0200 Subject: [PATCH 08/45] add regression test in utils --- doc/whats_new/v0.22.rst | 8 +++++++ .../compose/tests/test_column_transformer.py | 4 ++-- sklearn/utils/tests/test_utils.py | 22 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index 93635d88069d5..114afb9185a18 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -61,6 +61,14 @@ Changelog `sample_weights` are not supported by the wrapped estimator). :pr:`13575` by :user:`William de Vazelhes `. +:mod:`sklearn.compose` +...................... + +- |Fix| Fixed a bug in :class:`compose.ColumnTransformer` which failed to + select the proper columns when using a boolean list and NumPy older than + 1.13. + :pr:`14510` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.datasets` ....................... 
diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 2ccfd6d6c2eae..a9c4fd9e25fbe 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -1112,8 +1112,8 @@ def test_column_transformer_reordered_column_names_remainder(explicit_colname): def test_column_transformer_mask_indexing(): - # Regression test for #xxxxx - # Boolean mask indexing with NumPy < 1.13 + # Regression test for #14510 + # Boolean array-like does not behave as boolean array with NumPy < 1.13 X = np.transpose([[1, 2, 3], [4, 5, 6], [5, 6, 7], [8, 9, 10]]) column_transformer = ColumnTransformer( [('identity', FunctionTransformer(), [False, True, False, True])] diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index a39e8160047a5..35cfde4aaef7d 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -10,6 +10,7 @@ from sklearn.utils.testing import (assert_raises, assert_array_equal, + assert_allclose_dense_sparse, assert_raises_regex, assert_warns_message, assert_no_warnings) from sklearn.utils import check_random_state @@ -365,6 +366,27 @@ def test_safe_indexing_mock_pandas(asarray): assert_array_equal(np.array(X_df_indexed), X_indexed) +@pytest.mark.parametrize("array_type", ['array', 'sparse', 'dataframe']) +def test_safe_indexing_mask_axis_1(array_type): + # regression test for #14510 + # check that boolean array-like and boolean array lead to the same indexing + # even in NumPy < 1.13 + if array_type == 'array': + array_constructor = np.asarray + elif array_type == 'sparse': + array_constructor = sp.csr_matrix + elif array_type == 'dataframe': + pd = pytest.importorskip('pandas') + array_constructor = pd.DataFrame + + X = array_constructor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + mask = [True, False, True] + mask_array = np.array(mask) + X_masked = safe_indexing(X, mask, axis=1) + X_masked_array = safe_indexing(X, mask_array, axis=1) + assert_allclose_dense_sparse(X_masked, X_masked_array) + + def test_shuffle_on_ndim_equals_three(): def to_tuple(A): # to make the inner arrays hashable return tuple(tuple(tuple(C) for C in B) for B in A) From c8009a28aa2855e8e01cfebd3ef5df337c0536f4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 30 Jul 2019 12:17:43 +0200 Subject: [PATCH 09/45] fix --- sklearn/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8db41bb27986d..ac6446afcd6cf 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -229,7 +229,7 @@ def safe_indexing(X, indices, axis=0): # FIXME: to be removed once NumPy 1.13 is the minimum version required def _array_indexing(array, key, axis=0): """Index an array consistently across NumPy version.""" - if np_version < (1, 13): + if np_version < (1, 13) or issparse(array): # check if we have an boolean array-likes to make the proper indexing key_array = np.asarray(key) if np.issubdtype(key_array.dtype, np.bool_): From a80b33d7a60d67706150d48f71b3d0837fceba38 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 30 Jul 2019 14:25:31 +0200 Subject: [PATCH 10/45] add test in column transformer --- sklearn/compose/tests/test_column_transformer.py | 5 ++++- sklearn/utils/__init__.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index 
a9c4fd9e25fbe..a667b35cf65e3 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -1111,10 +1111,13 @@ def test_column_transformer_reordered_column_names_remainder(explicit_colname): tf.transform(X_array) -def test_column_transformer_mask_indexing(): +@pytest.mark.parametrize("array_type", [np.asarray, sparse.csr_matrix]) +def test_column_transformer_mask_indexing(array_type): # Regression test for #14510 # Boolean array-like does not behave as boolean array with NumPy < 1.13 + # and sparse matrices as well X = np.transpose([[1, 2, 3], [4, 5, 6], [5, 6, 7], [8, 9, 10]]) + X = array_type(X) column_transformer = ColumnTransformer( [('identity', FunctionTransformer(), [False, True, False, True])] ) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ac6446afcd6cf..83f4d7fd1876c 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -226,7 +226,6 @@ def safe_indexing(X, indices, axis=0): ) -# FIXME: to be removed once NumPy 1.13 is the minimum version required def _array_indexing(array, key, axis=0): """Index an array consistently across NumPy version.""" if np_version < (1, 13) or issparse(array): From 9fb045dcf1b7923bf06021b6944ca0cb3dd8ad40 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 14:22:46 +0200 Subject: [PATCH 11/45] raise error if axis not 0 or 1 --- sklearn/utils/__init__.py | 5 +++++ sklearn/utils/tests/test_utils.py | 8 ++++++++ 2 files changed, 13 insertions(+) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 83f4d7fd1876c..3b4a20d08716b 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -228,6 +228,11 @@ def safe_indexing(X, indices, axis=0): def _array_indexing(array, key, axis=0): """Index an array consistently across NumPy version.""" + if axis not in (0, 1): + raise ValueError( + "'axis' should be either 0 (to index rows) or 1 (to index " + " column). 
Got {} instead.".format(axis) + ) if np_version < (1, 13) or issparse(array): # check if we have an boolean array-likes to make the proper indexing key_array = np.asarray(key) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 35cfde4aaef7d..49f50eedc0a42 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -13,6 +13,7 @@ assert_allclose_dense_sparse, assert_raises_regex, assert_warns_message, assert_no_warnings) +from sklearn.utils import _array_indexing from sklearn.utils import check_random_state from sklearn.utils import _check_key_type from sklearn.utils import deprecated @@ -387,6 +388,13 @@ def test_safe_indexing_mask_axis_1(array_type): assert_allclose_dense_sparse(X_masked, X_masked_array) +def test_array_indexing_array_error(): + X = np.array([[0, 1], [2, 3]]) + mask = [True, False] + with pytest.raises(ValueError, match="'axis' should be either 0"): + _array_indexing(X, mask, axis=3) + + def test_shuffle_on_ndim_equals_three(): def to_tuple(A): # to make the inner arrays hashable return tuple(tuple(tuple(C) for C in B) for B in A) From 5dcf34f3979c365f99b94b4f4a116c2ec4ed18b2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 14:40:02 +0200 Subject: [PATCH 12/45] itert --- sklearn/utils/__init__.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index c41325c03d307..901ad3561dd5c 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -268,10 +268,10 @@ def _safe_indexing_row(X, indices): elif not isinstance(indices, slice): indices = np.asarray(indices) if hasattr(X, "iloc"): + # Pandas Dataframes and Series if not isinstance(indices, slice): # Work-around for indexing with read-only indices in pandas indices = indices if indices.flags.writeable else indices.copy() - # Pandas Dataframes and Series try: return X.iloc[indices] except ValueError: @@ -281,20 +281,14 @@ def _safe_indexing_row(X, indices): DataConversionWarning) return X.copy().iloc[indices] elif hasattr(X, "shape"): - if hasattr(X, 'take') and (hasattr(indices, 'dtype') and - indices.dtype.kind == 'i'): - # This is often substantially faster than X[indices] - return X.take(indices, axis=0) - else: - return _array_indexing(X, indices, axis=0) - else: + return _array_indexing(X, indices, axis=0) + elif not isinstance(indices, Iterable) or indices.ndim == 0: # In the case of a slice or a scalar - if not isinstance(indices, Iterable) or indices.ndim == 0: - return X[indices] - else: - if np.issubdtype(indices.dtype, np.bool_): - indices = np.flatnonzero(indices) - return [X[idx] for idx in indices] + return X[indices] + else: + if np.issubdtype(indices.dtype, np.bool_): + indices = np.flatnonzero(indices) + return [X[idx] for idx in indices] def _check_key_type(key, superclass): From 70f0e023b76558bad6226aa92a25c1024870e839 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 16:52:10 +0200 Subject: [PATCH 13/45] iter --- sklearn/utils/__init__.py | 68 +++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 901ad3561dd5c..cad6ecaa64f80 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -4,6 +4,7 @@ from collections.abc import Iterable from collections.abc import Sequence from contextlib import contextmanager +from itertools import compress from itertools import islice 
import numbers import platform @@ -188,19 +189,17 @@ def safe_indexing(X, indices, axis=0): ---------- X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series Data from which to sample rows, items or columns. - indices : array-like - - For both `axis=0` and `axis=1`, indices can be one of: - - scalar: output is 1D, unless `X` is sparse. - Supported data types for scalars: - - integer: supported for arrays, sparse matrices and - dataframes. - - container: lists, slices, boolean masks: output is 2D. - Supported data types for containers: - - integer or boolean (positional): supported for arrays, sparse - matrices and dataframes - - If `X` is a dataframe and `axis=1`, indices support string data type - (key-based) as a scalar or a container. The output dimension will be - identical to the above case. + indices : bool, int, str, array-like + - To select a single element (i.e. row or column), `indices` can be one + of the following: `bool` or `int` which are supported by all types of + `X`. `indices` being a `str` is only supported for `X` being a + dataframe. The selected subset will be 1D, unless `X` is a sparse + matrix in which case it will be 2D. + - To select multiple elements (i.e. rows or columns), `indices` can be + one of the following: `list`, `array`, `slice`. The type used in + these containers can be one of the following: `int`, `bool`, and + `str`. `str` is only supported when `X` is a dataframe. + The selected subset will be 2D. axis : int, default=0 The axis along which `X` will be subsampled. ``axis=0`` will select rows while ``axis=1`` will select columns. @@ -241,17 +240,18 @@ def _array_indexing(array, key, axis=0): return array[key] if axis == 0 else array[:, key] -def _safe_indexing_row(X, indices): +def _safe_indexing_row(X, key): """Return items or rows from X using indices. - Allows simple indexing of lists or arrays. + Allows simple indexing of lists, NumPy array, SciPy sparse matrices, and + Pandas DataFrame`. Parameters ---------- X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series Data from which to sample rows or items. - indices : array-like of int - Indices according to which X will be subsampled. + key : int, slice, bool or int array-like + Key used to get a subset of X. Returns ------- @@ -263,32 +263,36 @@ def _safe_indexing_row(X, indices): CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are not supported. """ - if indices is None: + if key is None: return X - elif not isinstance(indices, slice): - indices = np.asarray(indices) if hasattr(X, "iloc"): # Pandas Dataframes and Series - if not isinstance(indices, slice): - # Work-around for indexing with read-only indices in pandas - indices = indices if indices.flags.writeable else indices.copy() + if hasattr(key, 'flags'): + # Work-around for indexing with read-only key in pandas + key = key if key.flags.writeable else key.copy() try: - return X.iloc[indices] + return X.iloc[key] except ValueError: # Cython typed memoryviews internally used in pandas do not support # readonly buffers. 
warnings.warn("Copying input dataframe for slicing.", DataConversionWarning) - return X.copy().iloc[indices] + return X.copy().iloc[key] elif hasattr(X, "shape"): - return _array_indexing(X, indices, axis=0) - elif not isinstance(indices, Iterable) or indices.ndim == 0: - # In the case of a slice or a scalar - return X[indices] + # NumPy array and SciPy sparse matrix + return _array_indexing(X, key, axis=0) else: - if np.issubdtype(indices.dtype, np.bool_): - indices = np.flatnonzero(indices) - return [X[idx] for idx in indices] + # Python list + if not isinstance(key, Iterable) or isinstance(indexable, slice): + # key being a slice or a scalar + return X[key] + key_set = set(key) + if len(key_set) == 2 and all(isinstance(k, (bool, np.bool_)) + for k in key_set): + # key is a boolean array-like + return list(compress(X, key)) + # key is a integer array-like of key + return [X[idx] for idx in key] def _check_key_type(key, superclass): From 7127b5a9cf2efcc01ac8c9028111d74a33fe1088 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 18:31:00 +0200 Subject: [PATCH 14/45] refactor --- sklearn/utils/__init__.py | 202 +++++++++++------------------- sklearn/utils/tests/test_utils.py | 7 -- 2 files changed, 71 insertions(+), 138 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index cad6ecaa64f80..5a41c4a8bd298 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -182,56 +182,8 @@ def axis0_safe_slice(X, mask, len_mask): return np.zeros(shape=(0, X.shape[1])) -def safe_indexing(X, indices, axis=0): - """Return rows, items or columns of X using indices. - - Parameters - ---------- - X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series - Data from which to sample rows, items or columns. - indices : bool, int, str, array-like - - To select a single element (i.e. row or column), `indices` can be one - of the following: `bool` or `int` which are supported by all types of - `X`. `indices` being a `str` is only supported for `X` being a - dataframe. The selected subset will be 1D, unless `X` is a sparse - matrix in which case it will be 2D. - - To select multiple elements (i.e. rows or columns), `indices` can be - one of the following: `list`, `array`, `slice`. The type used in - these containers can be one of the following: `int`, `bool`, and - `str`. `str` is only supported when `X` is a dataframe. - The selected subset will be 2D. - axis : int, default=0 - The axis along which `X` will be subsampled. ``axis=0`` will select - rows while ``axis=1`` will select columns. - - Returns - ------- - subset - Subset of X on axis 0 or 1. - - Notes - ----- - CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are - not supported. - """ - if axis == 0: - return _safe_indexing_row(X, indices) - elif axis == 1: - return _safe_indexing_column(X, indices) - else: - raise ValueError( - "'axis' should be either 0 (to index rows) or 1 (to index " - " column). Got {} instead.".format(axis) - ) - - -def _array_indexing(array, key, axis=0): +def _array_indexing(array, key, axis): """Index an array consistently across NumPy version.""" - if axis not in (0, 1): - raise ValueError( - "'axis' should be either 0 (to index rows) or 1 (to index " - " column). 
Got {} instead.".format(axis) - ) if np_version < (1, 13) or issparse(array): # check if we have an boolean array-likes to make the proper indexing key_array = np.asarray(key) @@ -240,48 +192,18 @@ def _array_indexing(array, key, axis=0): return array[key] if axis == 0 else array[:, key] -def _safe_indexing_row(X, key): - """Return items or rows from X using indices. +def _pandas_indexing(X, key, axis, by_name): + """Index a pandas dataframe or a series.""" + if hasattr(key, 'flags'): + # Work-around for indexing with read-only key in pandas + key = key if key.flags.writeable else key.copy() + indexer = 'loc' if by_name else 'iloc' + return getattr(X, indexer)[:, key] if axis else getattr(X, indexer)[key] - Allows simple indexing of lists, NumPy array, SciPy sparse matrices, and - Pandas DataFrame`. - Parameters - ---------- - X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series - Data from which to sample rows or items. - key : int, slice, bool or int array-like - Key used to get a subset of X. - - Returns - ------- - subset - Subset of X on first axis. - - Notes - ----- - CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are - not supported. - """ - if key is None: - return X - if hasattr(X, "iloc"): - # Pandas Dataframes and Series - if hasattr(key, 'flags'): - # Work-around for indexing with read-only key in pandas - key = key if key.flags.writeable else key.copy() - try: - return X.iloc[key] - except ValueError: - # Cython typed memoryviews internally used in pandas do not support - # readonly buffers. - warnings.warn("Copying input dataframe for slicing.", - DataConversionWarning) - return X.copy().iloc[key] - elif hasattr(X, "shape"): - # NumPy array and SciPy sparse matrix - return _array_indexing(X, key, axis=0) - else: +def _list_indexing(X, key, axis): + """Index a Python list.""" + if axis == 0: # Python list if not isinstance(key, Iterable) or isinstance(indexable, slice): # key being a slice or a scalar @@ -315,7 +237,7 @@ def _check_key_type(key, superclass): return (isinstance(key.start, (superclass, type(None))) and isinstance(key.stop, (superclass, type(None)))) if isinstance(key, list): - return all(isinstance(x, superclass) for x in key) + return all(isinstance(x, superclass) for x in set(key)) if hasattr(key, 'dtype'): if superclass is int: return key.dtype.kind == 'i' @@ -327,60 +249,78 @@ def _check_key_type(key, superclass): return False -def _safe_indexing_column(X, key): - """Get feature column(s) from input data X. +def safe_indexing(X, indices, axis=0): + """Return rows, items or columns of X using indices. - Supported input types (X): numpy arrays, sparse arrays and DataFrames. + Parameters + ---------- + X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series + Data from which to sample rows, items or columns. + indices : bool, int, str, array-like + - To select a single element (i.e. row or column), `indices` can be one + of the following: `bool` or `int` which are supported by all types of + `X`. `indices` being a `str` is only supported for `X` being a + dataframe. The selected subset will be 1D, unless `X` is a sparse + matrix in which case it will be 2D. + - To select multiple elements (i.e. rows or columns), `indices` can be + one of the following: `list`, `array`, `slice`. The type used in + these containers can be one of the following: `int`, `bool`, and + `str`. `str` is only supported when `X` is a dataframe. + The selected subset will be 2D. 
+ axis : int, default=0 + The axis along which `X` will be subsampled. ``axis=0`` will select + rows while ``axis=1`` will select columns. - Supported key types (key): - - scalar: output is 1D; - - lists, slices, boolean masks: output is 2D. + Returns + ------- + subset + Subset of X on axis 0 or 1. - Supported key data types: - - integer or boolean mask (positional): - - supported for arrays, sparse matrices and dataframes. - - string (key-based): - - only supported for dataframes; - - So no keys other than strings are allowed (while in principle you - can use any hashable object as key). + Notes + ----- + CSR, CSC, and LIL sparse matrices are supported. COO sparse matrices are + not supported. """ - # check that X is a 2D structure - if X.ndim != 2: - raise ValueError( - "'X' should be a 2D NumPy array, 2D sparse matrix or pandas " - "dataframe when indexing the columns (i.e. 'axis=1'). " - "Got {} instead with {} dimension(s).".format(type(X), X.ndim) - ) - # check whether we have string column names or integers - if _check_key_type(key, int): - column_names = False - elif _check_key_type(key, str): - column_names = True - elif hasattr(key, 'dtype') and np.issubdtype(key.dtype, np.bool_): + if indices is None: + return X + if _check_key_type(indices, int): + by_name = False + elif _check_key_type(indices, str): + by_name = True + elif _check_key_type(indices, bool): # boolean mask - column_names = False + by_name = False if hasattr(X, 'loc'): # pandas boolean masks don't work with iloc, so take loc path - column_names = True + by_name = True else: raise ValueError("No valid specification of the columns. Only a " "scalar, list or slice of all integers or all " "strings, or boolean mask is allowed") - if column_names: - if hasattr(X, 'loc'): - # pandas dataframes - return X.loc[:, key] - else: - raise ValueError("Specifying the columns using strings is only " - "supported for pandas DataFrames") + if axis not in (0, 1): + raise ValueError( + "'axis' should be either 0 (to index rows) or 1 (to index " + " column). Got {} instead.".format(axis) + ) + + if axis == 1 and X.ndim != 2: + raise ValueError( + "'X' should be a 2D NumPy array, 2D sparse matrix or pandas " + "dataframe when indexing the columns (i.e. 'axis=1'). 
" + "Got {} instead with {} dimension(s).".format(type(X), X.ndim) + ) + + if by_name and not hasattr(X, 'loc'): + raise ValueError("Specifying the columns using strings is only " + "supported for pandas DataFrames") + + if hasattr(X, "iloc"): + return _pandas_indexing(X, indices, axis=axis, by_name=by_name) + elif hasattr(X, "shape"): + return _array_indexing(X, indices, axis=axis) else: - if hasattr(X, 'iloc'): - # pandas dataframes - return X.iloc[:, key] - else: - # numpy arrays, sparse arrays - return _array_indexing(X, key, axis=1) + return _list_indexing(X, indices, axis=axis) def _get_column_indices(X, key): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index da7306c1f2a1b..73b342d0e5e50 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -442,13 +442,6 @@ def test_safe_indexing_mask_axis_1(array_type): assert_allclose_dense_sparse(X_masked, X_masked_array) -def test_array_indexing_array_error(): - X = np.array([[0, 1], [2, 3]]) - mask = [True, False] - with pytest.raises(ValueError, match="'axis' should be either 0"): - _array_indexing(X, mask, axis=3) - - def test_shuffle_on_ndim_equals_three(): def to_tuple(A): # to make the inner arrays hashable return tuple(tuple(tuple(C) for C in B) for B in A) From 2f96882bbf9619c4c5cc4b08de12dec7590e1251 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 18:39:02 +0200 Subject: [PATCH 15/45] PEP8 comments --- sklearn/utils/__init__.py | 6 +++--- sklearn/utils/tests/test_utils.py | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5a41c4a8bd298..47b704a59d12d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -256,7 +256,7 @@ def safe_indexing(X, indices, axis=0): ---------- X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series Data from which to sample rows, items or columns. - indices : bool, int, str, array-like + indices : bool, int, str, slice, array-like - To select a single element (i.e. row or column), `indices` can be one of the following: `bool` or `int` which are supported by all types of `X`. `indices` being a `str` is only supported for `X` being a @@ -264,8 +264,8 @@ def safe_indexing(X, indices, axis=0): matrix in which case it will be 2D. - To select multiple elements (i.e. rows or columns), `indices` can be one of the following: `list`, `array`, `slice`. The type used in - these containers can be one of the following: `int`, `bool`, and - `str`. `str` is only supported when `X` is a dataframe. + these containers can be one of the following: `int` and `str`. + However, `str` is only supported when `X` is a dataframe. The selected subset will be 2D. axis : int, default=0 The axis along which `X` will be subsampled. 
``axis=0`` will select diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 73b342d0e5e50..6ebc378890bf3 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -13,7 +13,6 @@ assert_allclose_dense_sparse, assert_raises_regex, assert_warns_message, assert_no_warnings) -from sklearn.utils import _array_indexing from sklearn.utils import check_random_state from sklearn.utils import _check_key_type from sklearn.utils import deprecated From 619fb0526c5b905e31ab11da580c197b505690ac Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 18:41:00 +0200 Subject: [PATCH 16/45] iter --- sklearn/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 47b704a59d12d..29e1b324ea9c6 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -255,7 +255,8 @@ def safe_indexing(X, indices, axis=0): Parameters ---------- X : array-like, sparse-matrix, list, pandas.DataFrame, pandas.Series - Data from which to sample rows, items or columns. + Data from which to sample rows, items or columns. `list` are only + supported when `axis=0`. indices : bool, int, str, slice, array-like - To select a single element (i.e. row or column), `indices` can be one of the following: `bool` or `int` which are supported by all types of From b7539bd75ea5bd8276621b16aa4f15165e66a71a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 18:46:55 +0200 Subject: [PATCH 17/45] style --- sklearn/utils/__init__.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 29e1b324ea9c6..9e6396e7b5ba1 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -201,20 +201,18 @@ def _pandas_indexing(X, key, axis, by_name): return getattr(X, indexer)[:, key] if axis else getattr(X, indexer)[key] -def _list_indexing(X, key, axis): +def _list_indexing(X, key): """Index a Python list.""" - if axis == 0: - # Python list - if not isinstance(key, Iterable) or isinstance(indexable, slice): - # key being a slice or a scalar - return X[key] - key_set = set(key) - if len(key_set) == 2 and all(isinstance(k, (bool, np.bool_)) - for k in key_set): - # key is a boolean array-like - return list(compress(X, key)) - # key is a integer array-like of key - return [X[idx] for idx in key] + if not isinstance(key, Iterable) or isinstance(indexable, slice): + # key is a slice or a scalar + return X[key] + key_set = set(key) + if (len(key_set) == 2 and + all(isinstance(k, (bool, np.bool_)) for k in key_set)): + # key is a boolean array-like + return list(compress(X, key)) + # key is a integer array-like of key + return [X[idx] for idx in key] def _check_key_type(key, superclass): @@ -321,7 +319,7 @@ def safe_indexing(X, indices, axis=0): elif hasattr(X, "shape"): return _array_indexing(X, indices, axis=axis) else: - return _list_indexing(X, indices, axis=axis) + return _list_indexing(X, indices) def _get_column_indices(X, key): From 92d1aaf596745c2abe7a3bce9484d34852505d92 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 1 Aug 2019 19:27:04 +0200 Subject: [PATCH 18/45] iter --- sklearn/utils/__init__.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 9e6396e7b5ba1..3ea68ac4ef470 100644 --- a/sklearn/utils/__init__.py +++ 
b/sklearn/utils/__init__.py @@ -194,11 +194,21 @@ def _array_indexing(array, key, axis): def _pandas_indexing(X, key, axis, by_name): """Index a pandas dataframe or a series.""" - if hasattr(key, 'flags'): + if hasattr(key, 'shape'): # Work-around for indexing with read-only key in pandas key = key if key.flags.writeable else key.copy() indexer = 'loc' if by_name else 'iloc' - return getattr(X, indexer)[:, key] if axis else getattr(X, indexer)[key] + try: + return (getattr(X, indexer)[:, key] + if axis else getattr(X, indexer)[key]) + except ValueError: + # Cython typed memoryviews internally used in pandas do not support + # readonly buffers. + warnings.warn( + "Copying input dataframe for slicing.", DataConversionWarning + ) + return (getattr(X.copy(), indexer)[:, key] + if axis else getattr(X.copy(), indexer)[key]) def _list_indexing(X, key): From b1918e83de705b97b789d970ecd471903141182c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 2 Aug 2019 15:33:41 +0200 Subject: [PATCH 19/45] address different comments --- doc/whats_new/v0.22.rst | 4 ++-- sklearn/compose/tests/test_column_transformer.py | 2 +- sklearn/utils/__init__.py | 2 +- sklearn/utils/tests/test_utils.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index db9e0e574da06..0f3c5665e3aa6 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -65,8 +65,8 @@ Changelog ...................... - |Fix| Fixed a bug in :class:`compose.ColumnTransformer` which failed to - select the proper columns when using a boolean list and NumPy older than - 1.13. + select the proper columns when using a boolean list, with NumPy older than + 1.12. :pr:`14510` by :user:`Guillaume Lemaitre `. :mod:`sklearn.datasets` diff --git a/sklearn/compose/tests/test_column_transformer.py b/sklearn/compose/tests/test_column_transformer.py index a667b35cf65e3..d28a82374ad5b 100644 --- a/sklearn/compose/tests/test_column_transformer.py +++ b/sklearn/compose/tests/test_column_transformer.py @@ -1114,7 +1114,7 @@ def test_column_transformer_reordered_column_names_remainder(explicit_colname): @pytest.mark.parametrize("array_type", [np.asarray, sparse.csr_matrix]) def test_column_transformer_mask_indexing(array_type): # Regression test for #14510 - # Boolean array-like does not behave as boolean array with NumPy < 1.13 + # Boolean array-like does not behave as boolean array with NumPy < 1.12 # and sparse matrices as well X = np.transpose([[1, 2, 3], [4, 5, 6], [5, 6, 7], [8, 9, 10]]) X = array_type(X) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 3b4a20d08716b..f95a0d6cccc57 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -233,7 +233,7 @@ def _array_indexing(array, key, axis=0): "'axis' should be either 0 (to index rows) or 1 (to index " " column). 
Got {} instead.".format(axis) ) - if np_version < (1, 13) or issparse(array): + if np_version < (1, 12) or issparse(array): # check if we have an boolean array-likes to make the proper indexing key_array = np.asarray(key) if np.issubdtype(key_array.dtype, np.bool_): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 49f50eedc0a42..806295f1aae28 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -371,7 +371,7 @@ def test_safe_indexing_mock_pandas(asarray): def test_safe_indexing_mask_axis_1(array_type): # regression test for #14510 # check that boolean array-like and boolean array lead to the same indexing - # even in NumPy < 1.13 + # even in NumPy < 1.12 if array_type == 'array': array_constructor = np.asarray elif array_type == 'sparse': From 6322f99d5d85918643e2ba1b86539cb654d8cb60 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 2 Aug 2019 15:54:55 +0200 Subject: [PATCH 20/45] iter --- sklearn/utils/__init__.py | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 19e5e78b27cf3..638dc2e6f64fc 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -185,6 +185,7 @@ def axis0_safe_slice(X, mask, len_mask): def _array_indexing(array, key, axis): """Index an array consistently across NumPy version.""" if np_version < (1, 12) or issparse(array): + # FIXME: Remove the check for NumPy when using >= 1.12 # check if we have an boolean array-likes to make the proper indexing key_array = np.asarray(key) if np.issubdtype(key_array.dtype, np.bool_): @@ -196,19 +197,11 @@ def _pandas_indexing(X, key, axis, by_name): """Index a pandas dataframe or a series.""" if hasattr(key, 'shape'): # Work-around for indexing with read-only key in pandas + # FIXME: solved in pandas 0.25 + key = np.asarray(key) key = key if key.flags.writeable else key.copy() indexer = 'loc' if by_name else 'iloc' - try: - return (getattr(X, indexer)[:, key] - if axis else getattr(X, indexer)[key]) - except ValueError: - # Cython typed memoryviews internally used in pandas do not support - # readonly buffers. 
- warnings.warn( - "Copying input dataframe for slicing.", DataConversionWarning - ) - return (getattr(X.copy(), indexer)[:, key] - if axis else getattr(X.copy(), indexer)[key]) + return (getattr(X, indexer)[:, key] if axis else getattr(X, indexer)[key]) def _list_indexing(X, key): From 4d4cc2db23c47601ae9d0af53cdca5e80330b3e2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Fri, 2 Aug 2019 17:12:41 +0200 Subject: [PATCH 21/45] update error message --- sklearn/utils/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 638dc2e6f64fc..d85d755c29eb4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -340,7 +340,8 @@ def _get_column_indices(X, key): idx = safe_indexing(np.arange(n_columns), key) except IndexError as e: raise ValueError( - 'all features must be in [0, %d]' % (n_columns - 1) + 'all features must be in [0, {}] or [-{}, 0]' + .format(n_columns - 1, n_columns) ) from e return np.atleast_1d(idx).tolist() elif _check_key_type(key, str): From 46f96a95f94c610a036a6fbeffe206a3774bb38f Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Aug 2019 11:07:32 +0200 Subject: [PATCH 22/45] refactor --- sklearn/utils/__init__.py | 59 ++++-- sklearn/utils/tests/test_utils.py | 318 +++++++++++++++--------------- 2 files changed, 200 insertions(+), 177 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index d85d755c29eb4..d2ddaf9e27e57 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -193,8 +193,24 @@ def _array_indexing(array, key, axis): return array[key] if axis == 0 else array[:, key] -def _pandas_indexing(X, key, axis, by_name): +def _pandas_indexing(X, key, axis): """Index a pandas dataframe or a series.""" + # check whether we should index with loc or iloc + if _check_key_type(key, int): + by_name = False + elif _check_key_type(key, str): + by_name = True + elif _check_key_type(key, bool): + # boolean mask + by_name = False + if hasattr(X, 'loc'): + # pandas boolean masks don't work with iloc, so take loc path + by_name = True + else: + raise ValueError("No valid specification of the columns. Only a " + "scalar, list or slice of all integers or all " + "strings, or boolean mask is allowed") + if hasattr(key, 'shape'): # Work-around for indexing with read-only key in pandas # FIXME: solved in pandas 0.25 @@ -238,7 +254,14 @@ def _check_key_type(key, superclass): return (isinstance(key.start, (superclass, type(None))) and isinstance(key.stop, (superclass, type(None)))) if isinstance(key, list): - return all(isinstance(x, superclass) for x in set(key)) + unique_key = set(key) + all_superclass = all(isinstance(x, superclass) for x in unique_key) + if superclass != int: + return all_superclass + # bool is a subclass of int, therefore we should check specifically for + # any bool + any_bool = any(isinstance(x, bool) for x in unique_key) + return all_superclass and not any_bool if hasattr(key, 'dtype'): if superclass is int: return key.dtype.kind == 'i' @@ -285,20 +308,6 @@ def safe_indexing(X, indices, axis=0): """ if indices is None: return X - if _check_key_type(indices, int): - by_name = False - elif _check_key_type(indices, str): - by_name = True - elif _check_key_type(indices, bool): - # boolean mask - by_name = False - if hasattr(X, 'loc'): - # pandas boolean masks don't work with iloc, so take loc path - by_name = True - else: - raise ValueError("No valid specification of the columns. 
Only a " - "scalar, list or slice of all integers or all " - "strings, or boolean mask is allowed") if axis not in (0, 1): raise ValueError( @@ -306,6 +315,14 @@ def safe_indexing(X, indices, axis=0): " column). Got {} instead.".format(axis) ) + if (axis == 0 and not + (isinstance(indices, Iterable) and _check_key_type(indices, int) or + isinstance(indices, int))): + raise ValueError( + "'axis=0' only support integer array-like or scalar integer " + "as indices. Got {} instead.".format(indices) + ) + if axis == 1 and X.ndim != 2: raise ValueError( "'X' should be a 2D NumPy array, 2D sparse matrix or pandas " @@ -313,12 +330,14 @@ def safe_indexing(X, indices, axis=0): "Got {} instead with {} dimension(s).".format(type(X), X.ndim) ) - if by_name and not hasattr(X, 'loc'): - raise ValueError("Specifying the columns using strings is only " - "supported for pandas DataFrames") + if axis == 1 and _check_key_type(indices, str) and not hasattr(X, 'loc'): + raise ValueError( + "Specifying the columns using strings is only supported for " + "pandas DataFrames" + ) if hasattr(X, "iloc"): - return _pandas_indexing(X, indices, axis=axis, by_name=by_name) + return _pandas_indexing(X, indices, axis=axis) elif hasattr(X, "shape"): return _array_indexing(X, indices, axis=axis) else: diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 0f2e1a526d711..1c97f3b71bb13 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -209,6 +209,7 @@ def test_column_or_1d(): (np.array([0, 1, 2], dtype=np.int64), int, True), (np.array([0, 1, 2], dtype=np.uint8), int, False), ([True, False], bool, True), + ([True, False], int, False), (np.array([True, False]), bool, True), (np.array([True, False]), int, False), ('col_0', str, True), @@ -221,140 +222,185 @@ def test_check_key_type(key, clazz, is_expected_type): assert _check_key_type(key, clazz) is is_expected_type +def _convert_container(container, constructor_name, columns_name=None): + if constructor_name == 'list': + return list(container) + elif constructor_name == 'array': + return np.asarray(container) + elif constructor_name == 'sparse': + return sp.csr_matrix(container) + elif constructor_name == 'dataframe': + pd = pytest.importorskip('pandas') + return pd.DataFrame(container, columns=columns_name) + elif constructor_name == 'series': + pd = pytest.importorskip('pandas') + return pd.Series(container) + elif constructor_name == 'slice': + return slice(container[0], container[1]) + + @pytest.mark.parametrize( - "idx", - [[0, 2], [True, False, True], # array-like - np.array([0, 2]), np.array([True, False, True])], # numpy array - ids=['list-indices', 'list-mask', 'array-indices', 'array-mask'] + "array_type", ["list", "array", "sparse", "dataframe"] ) @pytest.mark.parametrize( - "array_type", [None, np.asarray, sp.csr_matrix], - ids=["list", "array", "sparse"] + "indices", + [_convert_container([1, 2], "list"), + _convert_container([1, 2], "array"), + _convert_container([1, 2], "series")] ) -def test_safe_indexing_axis_0_container(idx, array_type): - X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) if array_type is not None else X - X_subset = safe_indexing(X, idx, axis=0) - X_expect = [[1, 2, 3], [7, 8, 9]] - X_expect = array_type(X_expect) if array_type is not None else X_expect - assert_allclose_dense_sparse(X_subset, X_expect) +def test_safe_indexing_2d_container_axis_0(array_type, indices): + array = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type) + subset = 
safe_indexing(array, indices, axis=0) + assert_allclose_dense_sparse( + subset, _convert_container([[4, 5, 6], [7, 8, 9]], array_type) + ) -@pytest.mark.parametrize( - "array_type", [list, np.asarray, sp.csr_matrix], - ids=["list", "array", "sparse"] -) -def test_safe_indexing_axis_0_slice(array_type): - X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) - idx = slice(0, 2) - X_subset = safe_indexing(X, idx, axis=0) - X_expect = [[1, 2, 3], [4, 5, 6]] - X_expect = array_type(X_expect) if array_type is not None else X_expect - assert_allclose_dense_sparse(X_subset, X_expect) +@pytest.mark.parametrize("array_type", ["list", "array", "series"]) +@pytest.mark.parametrize("indices_type", ["list", "array", "series"]) +def test_safe_indexing_1d_container(array_type, indices_type): + array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type) + indices = _convert_container([1, 2], indices_type) + subset = safe_indexing(array, indices, axis=0) + assert_allclose_dense_sparse( + subset, _convert_container([2, 3], array_type) + ) -@pytest.mark.parametrize( - "array_type", [list, np.asarray, sp.csr_matrix], - ids=["list", "array", "sparse"] -) -def test_safe_indexing_axis_0_scalar(array_type): - X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) - idx = 1 # scalar indexing - X_subset = safe_indexing(X, idx, axis=0) - X_expect = [4, 5, 6] - X_expect = array_type(X_expect) if array_type is not None else X_expect - assert_allclose_dense_sparse(X_subset, X_expect) +@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) +@pytest.mark.parametrize("indices_type", ["list", "array", "series", "slice"]) +@pytest.mark.parametrize("indices", [[1, 2], ["col_1", "col_2"]]) +def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices): + # validation of the indices + if indices_type == 'slice' and isinstance(indices[1], int): + indices[1] += 1 + + columns_name = ['col_0', 'col_1', 'col_2'] + array = _convert_container( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name + ) + indices_converted = _convert_container(indices, indices_type) + + if isinstance(indices[0], str) and array_type != 'dataframe': + err_msg = ("Specifying the columns using strings is only supported " + "for pandas DataFrames") + with pytest.raises(ValueError, match=err_msg): + safe_indexing(array, indices_converted, axis=1) + else: + subset = safe_indexing(array, indices_converted, axis=1) + assert_allclose_dense_sparse( + subset, _convert_container([[2, 3], [5, 6], [8, 9]], array_type) + ) +@pytest.mark.parametrize("array_read_only", [True, False]) +@pytest.mark.parametrize("indices_read_only", [True, False]) +@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) +@pytest.mark.parametrize("indices_type", ["array", "series"]) @pytest.mark.parametrize( - "array_type", [None, np.asarray, sp.csr_matrix], - ids=["list", "array", "sparse"] + "axis, expected_array", + [(0, [[4, 5, 6], [7, 8, 9]]), (1, [[2, 3], [5, 6], [8, 9]])] ) -def test_safe_indexing_axis_0_None(array_type): - X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) if array_type is not None else X - X_subset = safe_indexing(X, None, axis=0) - assert_allclose_dense_sparse(X_subset, X) - +def test_safe_indexing_2d_read_only_axis_1(array_read_only, indices_read_only, + array_type, indices_type, axis, + expected_array): + array = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + if array_read_only: + array.setflags(write=False) + array = _convert_container(array, array_type) + indices = np.array([1, 2]) 
+ if indices_read_only: + indices.setflags(write=False) + indices = _convert_container(indices, indices_type) + subset = safe_indexing(array, indices, axis=axis) + assert_allclose_dense_sparse( + subset, _convert_container(expected_array, array_type) + ) -@pytest.mark.parametrize("idx", [0, [0, 1]], ids=['scalar', 'list']) -@pytest.mark.parametrize("asarray", [True, False], ids=["array-like", "array"]) -def test_safe_indexing_axis_1_sparse(idx, asarray): - if isinstance(idx, Iterable) and asarray: - idx = np.asarray(idx) - X_true = safe_indexing(X_toy, idx, axis=1) - # scipy matrix will always return a 2D array - if X_true.ndim == 1: - X_true = X_true[:, np.newaxis] +@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) +@pytest.mark.parametrize("indices_type", ["list", "array", "series"]) +def test_safe_indexing_2d_mask_axis_1(array_type, indices_type): + columns_name = ['col_0', 'col_1', 'col_2'] + array = _convert_container( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name + ) + indices = [False, True, True] + indices = _convert_container(indices, indices_type) - X_sparse = sp.csc_matrix(X_toy) - assert_array_equal( - safe_indexing(X_sparse, idx, axis=1).toarray(), X_true + subset = safe_indexing(array, indices, axis=1) + assert_allclose_dense_sparse( + subset, _convert_container([[2, 3], [5, 6], [8, 9]], array_type) ) @pytest.mark.parametrize( - "idx_array, idx_df", - [(0, 0), - (0, 'col_0'), - ([0, 1], [0, 1]), - ([0, 1], ['col_0', 'col_1']), - ([0, 1], slice(0, 2)), - ([1, 2], slice(1, None)), - ([0, 1], [True, True, False])], - ids=['scalar-int', 'scalar-str', 'list-int', 'list-str', 'slice', - 'slice-no-stop', 'mask'] + "array_type, expected_output_type", + [("list", "list"), ("array", "array"), + ("sparse", "sparse"), ("dataframe", "series")] ) -@pytest.mark.parametrize("asarray", [True, False], ids=["array-like", "array"]) -def test_safe_indexing_axis_1_pandas(idx_array, idx_df, asarray): - pd = pytest.importorskip('pandas') - if asarray and isinstance(idx_array, Iterable): - idx_array = np.asarray(idx_array) - if (asarray and (not isinstance(idx_df, str) and - isinstance(idx_df, Iterable))): - idx_df = np.asarray(idx_df) - - X_true = safe_indexing(X_toy, idx_array, axis=1) - X_df = pd.DataFrame(X_toy, columns=['col_{}'.format(i) for i in range(3)]) - assert_array_equal( - safe_indexing(X_df, idx_df, axis=1).values, X_true - ) +def test_safe_indexing_2d_scalar_axis_0(array_type, expected_output_type): + array = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type) + indices = 2 + subset = safe_indexing(array, indices, axis=0) + expected_array = _convert_container([7, 8, 9], expected_output_type) + assert_allclose_dense_sparse(subset, expected_array) -def test_safe_indexing_pandas(): - pd = pytest.importorskip("pandas") - X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - X_df = pd.DataFrame(X) - inds = np.array([1, 2]) - X_df_indexed = safe_indexing(X_df, inds) - X_indexed = safe_indexing(X_df, inds) - assert_array_equal(np.array(X_df_indexed), X_indexed) - # fun with read-only data in dataframes - # this happens in joblib memmapping - X.setflags(write=False) - X_df_readonly = pd.DataFrame(X) - inds_readonly = inds.copy() - inds_readonly.setflags(write=False) - - for this_df, this_inds in product([X_df, X_df_readonly], - [inds, inds_readonly]): - with warnings.catch_warnings(record=True): - X_df_indexed = safe_indexing(this_df, this_inds) - - assert_array_equal(np.array(X_df_indexed), X_indexed) +@pytest.mark.parametrize("array_type", 
["list", "array", "series"]) +def test_safe_indexing_1d_scalar(array_type): + array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type) + indices = 2 + subset = safe_indexing(array, indices, axis=0) + assert subset == 3 + + +@pytest.mark.parametrize( + "array_type, expected_output_type", + [("array", "array"), ("sparse", "sparse"), ("dataframe", "series")] +) +@pytest.mark.parametrize("indices", [2, "col_2"]) +def test_safe_indexing_2d_scalar_axis_1(array_type, expected_output_type, + indices): + columns_name = ['col_0', 'col_1', 'col_2'] + array = _convert_container( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name + ) + + if isinstance(indices, str) and array_type != 'dataframe': + err_msg = ("Specifying the columns using strings is only supported " + "for pandas DataFrames") + with pytest.raises(ValueError, match=err_msg): + safe_indexing(array, indices, axis=1) + else: + subset = safe_indexing(array, indices, axis=1) + expected_output = [3, 6, 9] + if expected_output_type == 'sparse': + # sparse matrix are keeping the 2D shape + expected_output = [[3], [6], [9]] + expected_array = _convert_container( + expected_output, expected_output_type + ) + assert_allclose_dense_sparse(subset, expected_array) @pytest.mark.parametrize( - "X, key, err_msg", - [(X_toy, 1.0, "No valid specification of the columns."), - (X_toy, ['col_0'], "Specifying the columns using strings is only")] + "array_type", [None, np.asarray, sp.csr_matrix], + ids=["list", "array", "sparse"] ) -def test_safe_indexing_axis_1_error(X, key, err_msg): +def test_safe_indexing_None_axis_0(array_type): + X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + X = array_type(X) if array_type is not None else X + X_subset = safe_indexing(X, None, axis=0) + assert_allclose_dense_sparse(X_subset, X) + + +def test_safe_indexing_pandas_no_matching_cols_error(): + pd = pytest.importorskip('pandas') + err_msg = "No valid specification of the columns." 
+ X = pd.DataFrame(X_toy) with pytest.raises(ValueError, match=err_msg): - safe_indexing(X, key, axis=1) + safe_indexing(X, [1.0], axis=1) @pytest.mark.parametrize("axis", [None, 3]) @@ -379,6 +425,14 @@ def test_safe_indexing_1d_array_error(X_constructor): safe_indexing(X_constructor, [0, 1], axis=1) +@pytest.mark.parametrize("indices", [["col_1", "col_2"], [True, False, True]]) +def test_safe_indexing_container_axis_0_unsupported_type(indices): + array = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + err_msg = "'axis=0' only support integer array-like or scalar integer" + with pytest.raises(ValueError, match=err_msg): + safe_indexing(array, indices, axis=0) + + @pytest.mark.parametrize( "key, err_msg", [(10, r"all features must be in \[0, 2\]"), @@ -392,56 +446,6 @@ def test_get_column_indices_error(key, err_msg): _get_column_indices(X_df, key) -@pytest.mark.parametrize( - "idx", - [[0, 1], - [True, True, False]] -) -@pytest.mark.parametrize("asarray", [True, False], ids=["array-like", "array"]) -def test_safe_indexing_pandas_series(idx, asarray): - pd = pytest.importorskip("pandas") - idx = np.asarray(idx) if asarray else idx - serie = pd.Series(np.arange(3)) - assert_array_equal(safe_indexing(serie, idx).values, [0, 1]) - - -@pytest.mark.parametrize( - "inds", - [None, - [1, 2], [False, True, True], - np.array([1, 2]), np.array([False, True, True]), - slice(1, None)], - ids=['None', 'list-int', 'list-mask', 'array-int', 'array-mask', 'slice'] -) -def test_safe_indexing_mock_pandas(inds): - X = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - X_df = MockDataFrame(X) - X_df_indexed = safe_indexing(X_df, inds) - X_indexed = safe_indexing(X_df, inds) - assert_array_equal(np.array(X_df_indexed), X_indexed) - - -@pytest.mark.parametrize("array_type", ['array', 'sparse', 'dataframe']) -def test_safe_indexing_mask_axis_1(array_type): - # regression test for #14510 - # check that boolean array-like and boolean array lead to the same indexing - # even in NumPy < 1.12 - if array_type == 'array': - array_constructor = np.asarray - elif array_type == 'sparse': - array_constructor = sp.csr_matrix - elif array_type == 'dataframe': - pd = pytest.importorskip('pandas') - array_constructor = pd.DataFrame - - X = array_constructor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) - mask = [True, False, True] - mask_array = np.array(mask) - X_masked = safe_indexing(X, mask, axis=1) - X_masked_array = safe_indexing(X, mask_array, axis=1) - assert_allclose_dense_sparse(X_masked, X_masked_array) - - def test_shuffle_on_ndim_equals_three(): def to_tuple(A): # to make the inner arrays hashable return tuple(tuple(tuple(C) for C in B) for B in A) From d880075396c74d6e618fa81b735fa1a3402c04a4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Aug 2019 11:14:50 +0200 Subject: [PATCH 23/45] documentation --- sklearn/utils/__init__.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index d2ddaf9e27e57..3359a3f789ecb 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -282,19 +282,20 @@ def safe_indexing(X, indices, axis=0): Data from which to sample rows, items or columns. `list` are only supported when `axis=0`. indices : bool, int, str, slice, array-like - - To select a single element (i.e. row or column), `indices` can be one - of the following: `bool` or `int` which are supported by all types of - `X`. `indices` being a `str` is only supported for `X` being a - dataframe. 
The selected subset will be 1D, unless `X` is a sparse - matrix in which case it will be 2D. - - To select multiple elements (i.e. rows or columns), `indices` can be - one of the following: `list`, `array`, `slice`. The type used in - these containers can be one of the following: `int` and `str`. - However, `str` is only supported when `X` is a dataframe. - The selected subset will be 2D. + - If `axis=0`, only integer array-like or scalar integer are supported. + - If `axis=1`: + - to select a single column, `indices` can be of `int` type for + all `X` types and `str` only for dataframe. The selected subset + will be 1D, unless `X` is a sparse matrix in which case it will + be 2D. + - to select multiples columns, `indices` can be one of the + following: `list`, `array`, `slice`. The type used in + these containers can be one of the following: `int` and `str`. + However, `str` is only supported when `X` is a dataframe. + The selected subset will be 2D. axis : int, default=0 - The axis along which `X` will be subsampled. ``axis=0`` will select - rows while ``axis=1`` will select columns. + The axis along which `X` will be subsampled. `axis=0` will select + rows while `axis=1` will select columns. Returns ------- From ae868e13f3a8cc17572a23f1629be1cbb2b23ffa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Aug 2019 11:20:01 +0200 Subject: [PATCH 24/45] whats new --- doc/whats_new/v0.22.rst | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/doc/whats_new/v0.22.rst b/doc/whats_new/v0.22.rst index ad64ef0c67f8e..65d69c022a6fa 100644 --- a/doc/whats_new/v0.22.rst +++ b/doc/whats_new/v0.22.rst @@ -271,12 +271,9 @@ Changelog - |Enhancement| :func:`utils.safe_indexing` accepts an ``axis`` parameter to index array-like across rows and columns. The column indexing can be done on - NumPy array, SciPy sparse matrix, and Pandas DataFrame. - :pr:`14035` by `Guillaume Lemaitre `. - -- |Enhancement| :func:`utils.safe_indexing` becomes more consistent and support - scalar, slice, mask indexing for both `axis=0` and `axis=1`. - :pr:`14475` by :user:`Guillaume Lemaitre `. + NumPy array, SciPy sparse matrix, and Pandas DataFrame. An additional + refactoring was done. + :pr:`14035` and :pr:`14475` by `Guillaume Lemaitre `. :mod:`sklearn.neighbors` ............................. 
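For reference, a minimal usage sketch of the behaviour described in the docstring and changelog hunks above — assuming the refactored `sklearn.utils.safe_indexing` from the preceding commits, with toy data mirroring the arrays used in the new tests — looks roughly like this (illustration only, not part of the patch series):

    import pandas as pd
    from sklearn.utils import safe_indexing

    # same 3x3 toy data as in the new tests (assumed here for illustration)
    X = pd.DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                     columns=['col_0', 'col_1', 'col_2'])

    # axis=0: select rows 1 and 2 by position; an integer array-like is
    # accepted for lists, NumPy arrays, sparse matrices and dataframes alike
    rows = safe_indexing(X, [1, 2], axis=0)

    # axis=1: select columns by name (only supported for pandas DataFrames) ...
    cols_by_name = safe_indexing(X, ['col_1', 'col_2'], axis=1)

    # ... or by positional indices / boolean mask, which also work for
    # NumPy arrays and sparse matrices
    cols_by_mask = safe_indexing(X, [False, True, True], axis=1)

In each case the returned subset keeps the container type of `X` (a DataFrame here), as exercised by `test_safe_indexing_2d_container_axis_0` and the axis=1 tests above.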
From ead08d60c05675b6d8b5cd994faef092e119984d Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Aug 2019 11:22:35 +0200 Subject: [PATCH 25/45] remove unused import --- sklearn/utils/tests/test_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 1c97f3b71bb13..e98315dc3ec76 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -1,5 +1,4 @@ -from collections.abc import Iterable -from itertools import chain, product +from itertools import chain import warnings import string import timeit @@ -13,7 +12,6 @@ assert_allclose_dense_sparse, assert_raises_regex, assert_warns_message, assert_no_warnings) -from sklearn.utils import _array_indexing from sklearn.utils import check_random_state from sklearn.utils import _check_key_type from sklearn.utils import deprecated From 655a218bce5d20d74f06d233082f03471c1dd6ae Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 7 Aug 2019 11:55:22 +0200 Subject: [PATCH 26/45] fix --- sklearn/utils/__init__.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 3359a3f789ecb..61dab7110c447 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -353,11 +353,17 @@ def _get_column_indices(X, key): """ n_columns = X.shape[1] - if (_check_key_type(key, int) - or hasattr(key, 'dtype') and np.issubdtype(key.dtype, np.bool_)): + if isinstance(key, list) and not key: + # we get an empty list + return [] + elif (_check_key_type(key, int) or _check_key_type(key, bool)): # Convert key into positive indexes try: - idx = safe_indexing(np.arange(n_columns), key) + # The behavior of boolean array-like and boolean array is different + # in NumPy < 1.12. We convert the array-like for consistent + # behavior. 
+ key = np.asarray(key) if _check_key_type(key, bool) else key + idx = np.arange(n_columns)[key] except IndexError as e: raise ValueError( 'all features must be in [0, {}] or [-{}, 0]' From db80c5aaa2c952161c0d85d2293e09f4acc0ad83 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 8 Aug 2019 10:02:52 +0200 Subject: [PATCH 27/45] use supported type in safe_indexing occurences --- sklearn/compose/_target.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/compose/_target.py b/sklearn/compose/_target.py index c1c3f4df4e95f..66f2fbf9ac144 100644 --- a/sklearn/compose/_target.py +++ b/sklearn/compose/_target.py @@ -8,7 +8,6 @@ from ..base import BaseEstimator, RegressorMixin, clone from ..utils.validation import check_is_fitted -from ..utils import check_array, safe_indexing from ..preprocessing import FunctionTransformer __all__ = ['TransformedTargetRegressor'] @@ -139,7 +138,7 @@ def _fit_transformer(self, y): self.transformer_.fit(y) if self.check_inverse: idx_selected = slice(None, None, max(1, y.shape[0] // 10)) - y_sel = safe_indexing(y, idx_selected) + y_sel = y[idx_selected] y_sel_t = self.transformer_.transform(y_sel) if not np.allclose(y_sel, self.transformer_.inverse_transform(y_sel_t)): From ce618fd48067e7a595851eaff035d1e37c6450a8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 8 Aug 2019 10:03:44 +0200 Subject: [PATCH 28/45] save the file --- sklearn/impute/_iterative.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index ecf94d5ccfb57..60cf3c57a7a64 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -272,15 +272,17 @@ def _impute_one_feature(self, estimator = clone(self._estimator) if fit_mode: + indices_row_available = np.flatnonzero(~missing_row_mask) X_train = safe_indexing(X_filled[:, neighbor_feat_idx], - ~missing_row_mask) + indices_row_available) y_train = safe_indexing(X_filled[:, feat_idx], - ~missing_row_mask) + indices_row_available) estimator.fit(X_train, y_train) # get posterior samples + indices_row_missing = np.flatnonzero(missing_row_mask) X_test = safe_indexing(X_filled[:, neighbor_feat_idx], - missing_row_mask) + indices_row_missing) if self.sample_posterior: mus, sigmas = estimator.predict(X_test, return_std=True) imputed_values = np.zeros(mus.shape, dtype=X_filled.dtype) From b8e98c23409612fdec393a2ac1e5ff53a3ee13a9 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 8 Aug 2019 10:31:37 +0200 Subject: [PATCH 29/45] ups --- sklearn/compose/_target.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/compose/_target.py b/sklearn/compose/_target.py index 66f2fbf9ac144..ac851a23bae3e 100644 --- a/sklearn/compose/_target.py +++ b/sklearn/compose/_target.py @@ -8,6 +8,7 @@ from ..base import BaseEstimator, RegressorMixin, clone from ..utils.validation import check_is_fitted +from ..utils import check_array from ..preprocessing import FunctionTransformer __all__ = ['TransformedTargetRegressor'] From 11dff0149d09f1340db9de62ab3cce11896adee2 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 8 Aug 2019 10:52:24 +0200 Subject: [PATCH 30/45] change test iid --- sklearn/metrics/cluster/unsupervised.py | 2 +- sklearn/model_selection/tests/test_search.py | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/cluster/unsupervised.py b/sklearn/metrics/cluster/unsupervised.py index 05206ab42a291..43552d77c4e42 100644 --- 
a/sklearn/metrics/cluster/unsupervised.py +++ b/sklearn/metrics/cluster/unsupervised.py @@ -339,7 +339,7 @@ def davies_bouldin_score(X, labels): intra_dists = np.zeros(n_labels) centroids = np.zeros((n_labels, len(X[0])), dtype=np.float) for k in range(n_labels): - cluster_k = safe_indexing(X, labels == k) + cluster_k = safe_indexing(X, np.flatnonzero(labels == k)) centroid = cluster_k.mean(axis=0) centroids[k] = centroid intra_dists[k] = np.average(pairwise_distances( diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index fe757457b3839..e97af1ccb9b32 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -943,7 +943,10 @@ def test_search_default_iid(SearchCV, specialized_params): # this leads to perfect classification on one fold and a score of 1/3 on # the other # create "cv" for splits - cv = [[mask, ~mask], [~mask, mask]] + first_fold_indices = np.flatnonzero(mask) + second_fold_indices = np.flatnonzero(~mask) + cv = [[first_fold_indices, second_fold_indices], + [second_fold_indices, first_fold_indices]] common_params = {'estimator': SVC(), 'cv': cv, 'return_train_score': True} @@ -992,7 +995,10 @@ def test_search_iid_param(): # this leads to perfect classification on one fold and a score of 1/3 on # the other # create "cv" for splits - cv = [[mask, ~mask], [~mask, mask]] + first_fold_indices = np.flatnonzero(mask) + second_fold_indices = np.flatnonzero(~mask) + cv = [[first_fold_indices, second_fold_indices], + [second_fold_indices, first_fold_indices]] # once with iid=True (default) grid_search = GridSearchCV(SVC(gamma='auto'), param_grid={'C': [1, 10]}, cv=cv, return_train_score=True, iid=True) From 94edba2343b09c317059a920558a9396c2925add Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Thu, 8 Aug 2019 11:27:13 +0200 Subject: [PATCH 31/45] gotcha --- sklearn/utils/tests/test_utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index e98315dc3ec76..0a961042348e4 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -1,3 +1,4 @@ +from copy import copy from itertools import chain import warnings import string @@ -270,14 +271,16 @@ def test_safe_indexing_1d_container(array_type, indices_type): @pytest.mark.parametrize("indices", [[1, 2], ["col_1", "col_2"]]) def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices): # validation of the indices + # we make a copy because indices is mutable and shared between tests + indices_converted = copy(indices) if indices_type == 'slice' and isinstance(indices[1], int): - indices[1] += 1 + indices_converted[1] += 1 columns_name = ['col_0', 'col_1', 'col_2'] array = _convert_container( [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name ) - indices_converted = _convert_container(indices, indices_type) + indices_converted = _convert_container(indices_converted, indices_type) if isinstance(indices[0], str) and array_type != 'dataframe': err_msg = ("Specifying the columns using strings is only supported " From ef4aa2c1ef7a1150773085255348c5da04d7ca61 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 14 Aug 2019 16:52:44 +0200 Subject: [PATCH 32/45] address comments thomas --- sklearn/utils/__init__.py | 9 +++------ sklearn/utils/tests/test_utils.py | 8 ++------ 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/sklearn/utils/__init__.py 
b/sklearn/utils/__init__.py index 61dab7110c447..61762fa091d6a 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -201,11 +201,8 @@ def _pandas_indexing(X, key, axis): elif _check_key_type(key, str): by_name = True elif _check_key_type(key, bool): - # boolean mask - by_name = False - if hasattr(X, 'loc'): - # pandas boolean masks don't work with iloc, so take loc path - by_name = True + # pandas boolean masks don't work with iloc, so take loc path + by_name = True else: raise ValueError("No valid specification of the columns. Only a " "scalar, list or slice of all integers or all " @@ -222,7 +219,7 @@ def _pandas_indexing(X, key, axis): def _list_indexing(X, key): """Index a Python list.""" - if not isinstance(key, Iterable) or isinstance(indexable, slice): + if not isinstance(key, Iterable): # key is a slice or a scalar return X[key] key_set = set(key) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 0a961042348e4..96d43204806c8 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -385,13 +385,9 @@ def test_safe_indexing_2d_scalar_axis_1(array_type, expected_output_type, assert_allclose_dense_sparse(subset, expected_array) -@pytest.mark.parametrize( - "array_type", [None, np.asarray, sp.csr_matrix], - ids=["list", "array", "sparse"] -) +@pytest.mark.parametrize("array_type", ["list", "array", "sparse"]) def test_safe_indexing_None_axis_0(array_type): - X = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - X = array_type(X) if array_type is not None else X + X = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type) X_subset = safe_indexing(X, None, axis=0) assert_allclose_dense_sparse(X_subset, X) From a7d29f67e6202dd23d98b3f6cc2a9b4647527efb Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 14 Aug 2019 18:47:59 +0200 Subject: [PATCH 33/45] be explicit --- sklearn/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 61762fa091d6a..5a26fcc5cc97d 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -219,7 +219,7 @@ def _pandas_indexing(X, key, axis): def _list_indexing(X, key): """Index a Python list.""" - if not isinstance(key, Iterable): + if np.isscalar(key) or isinstance(key, slice): # key is a slice or a scalar return X[key] key_set = set(key) From 3624cc56a5ee3ce35582ec02b6e2d78f4d0512db Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 14 Aug 2019 22:56:50 +0200 Subject: [PATCH 34/45] iter --- sklearn/utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5a26fcc5cc97d..3f5155eed6d82 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -287,8 +287,8 @@ def safe_indexing(X, indices, axis=0): be 2D. - to select multiples columns, `indices` can be one of the following: `list`, `array`, `slice`. The type used in - these containers can be one of the following: `int` and `str`. - However, `str` is only supported when `X` is a dataframe. + these containers can be one of the following: `int`, 'bool' and + `str`. However, `str` is only supported when `X` is a dataframe. The selected subset will be 2D. axis : int, default=0 The axis along which `X` will be subsampled. 
`axis=0` will select From 557aa433cdc6091d574f9419c660ff3f2cbe9337 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Aug 2019 14:43:18 +0200 Subject: [PATCH 35/45] add back support for mask and slice --- sklearn/compose/_target.py | 4 +- sklearn/impute/_iterative.py | 7 ++- sklearn/metrics/cluster/unsupervised.py | 2 +- sklearn/model_selection/tests/test_search.py | 10 +--- sklearn/utils/__init__.py | 18 +++----- sklearn/utils/tests/test_utils.py | 48 ++++++++++++++------ 6 files changed, 49 insertions(+), 40 deletions(-) diff --git a/sklearn/compose/_target.py b/sklearn/compose/_target.py index 502445a86b8f2..35b7ed6af962a 100644 --- a/sklearn/compose/_target.py +++ b/sklearn/compose/_target.py @@ -8,7 +8,7 @@ from ..base import BaseEstimator, RegressorMixin, clone from ..utils.validation import check_is_fitted -from ..utils import check_array +from ..utils import check_array, safe_indexing from ..preprocessing import FunctionTransformer __all__ = ['TransformedTargetRegressor'] @@ -139,7 +139,7 @@ def _fit_transformer(self, y): self.transformer_.fit(y) if self.check_inverse: idx_selected = slice(None, None, max(1, y.shape[0] // 10)) - y_sel = y[idx_selected] + y_sel = safe_indexing(y, idx_selected) y_sel_t = self.transformer_.transform(y_sel) if not np.allclose(y_sel, self.transformer_.inverse_transform(y_sel_t)): diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index b716aa6c1b37a..1fd60fa0e1768 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -272,17 +272,16 @@ def _impute_one_feature(self, estimator = clone(self._estimator) if fit_mode: - indices_row_available = np.flatnonzero(~missing_row_mask) X_train = safe_indexing(X_filled[:, neighbor_feat_idx], - indices_row_available) + ~missing_row_mask) y_train = safe_indexing(X_filled[:, feat_idx], - indices_row_available) + missing_row_mask) estimator.fit(X_train, y_train) # get posterior samples indices_row_missing = np.flatnonzero(missing_row_mask) X_test = safe_indexing(X_filled[:, neighbor_feat_idx], - indices_row_missing) + missing_row_mask) if self.sample_posterior: mus, sigmas = estimator.predict(X_test, return_std=True) imputed_values = np.zeros(mus.shape, dtype=X_filled.dtype) diff --git a/sklearn/metrics/cluster/unsupervised.py b/sklearn/metrics/cluster/unsupervised.py index 8191f88193162..2aa9c81c49abc 100644 --- a/sklearn/metrics/cluster/unsupervised.py +++ b/sklearn/metrics/cluster/unsupervised.py @@ -350,7 +350,7 @@ def davies_bouldin_score(X, labels): intra_dists = np.zeros(n_labels) centroids = np.zeros((n_labels, len(X[0])), dtype=np.float) for k in range(n_labels): - cluster_k = safe_indexing(X, np.flatnonzero(labels == k)) + cluster_k = safe_indexing(X, labels == k) centroid = cluster_k.mean(axis=0) centroids[k] = centroid intra_dists[k] = np.average(pairwise_distances( diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index e97af1ccb9b32..fe757457b3839 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -943,10 +943,7 @@ def test_search_default_iid(SearchCV, specialized_params): # this leads to perfect classification on one fold and a score of 1/3 on # the other # create "cv" for splits - first_fold_indices = np.flatnonzero(mask) - second_fold_indices = np.flatnonzero(~mask) - cv = [[first_fold_indices, second_fold_indices], - [second_fold_indices, first_fold_indices]] + cv = [[mask, ~mask], [~mask, mask]] common_params = {'estimator': 
SVC(), 'cv': cv, 'return_train_score': True} @@ -995,10 +992,7 @@ def test_search_iid_param(): # this leads to perfect classification on one fold and a score of 1/3 on # the other # create "cv" for splits - first_fold_indices = np.flatnonzero(mask) - second_fold_indices = np.flatnonzero(~mask) - cv = [[first_fold_indices, second_fold_indices], - [second_fold_indices, first_fold_indices]] + cv = [[mask, ~mask], [~mask, mask]] # once with iid=True (default) grid_search = GridSearchCV(SVC(gamma='auto'), param_grid={'C': [1, 10]}, cv=cv, return_train_score=True, iid=True) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5a26fcc5cc97d..1bbe15b4e74a9 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -279,7 +279,8 @@ def safe_indexing(X, indices, axis=0): Data from which to sample rows, items or columns. `list` are only supported when `axis=0`. indices : bool, int, str, slice, array-like - - If `axis=0`, only integer array-like or scalar integer are supported. + - If `axis=0`, boolean and integer array-like, integer slice, + and scalar integer are supported. - If `axis=1`: - to select a single column, `indices` can be of `int` type for all `X` types and `str` only for dataframe. The selected subset @@ -313,12 +314,11 @@ def safe_indexing(X, indices, axis=0): " column). Got {} instead.".format(axis) ) - if (axis == 0 and not - (isinstance(indices, Iterable) and _check_key_type(indices, int) or - isinstance(indices, int))): + if axis == 0 and _check_key_type(indices, str): raise ValueError( - "'axis=0' only support integer array-like or scalar integer " - "as indices. Got {} instead.".format(indices) + "'axis=0' only support integer or boolean array-like, slice with " + "integer, or scalar integer as indices. Got {} instead." + .format(indices) ) if axis == 1 and X.ndim != 2: @@ -356,11 +356,7 @@ def _get_column_indices(X, key): elif (_check_key_type(key, int) or _check_key_type(key, bool)): # Convert key into positive indexes try: - # The behavior of boolean array-like and boolean array is different - # in NumPy < 1.12. We convert the array-like for consistent - # behavior. 
- key = np.asarray(key) if _check_key_type(key, bool) else key - idx = np.arange(n_columns)[key] + idx = safe_indexing(np.arange(n_columns), key) except IndexError as e: raise ValueError( 'all features must be in [0, {}] or [-{}, 0]' diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 96d43204806c8..4e774d6df6884 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -241,14 +241,13 @@ def _convert_container(container, constructor_name, columns_name=None): @pytest.mark.parametrize( "array_type", ["list", "array", "sparse", "dataframe"] ) -@pytest.mark.parametrize( - "indices", - [_convert_container([1, 2], "list"), - _convert_container([1, 2], "array"), - _convert_container([1, 2], "series")] -) -def test_safe_indexing_2d_container_axis_0(array_type, indices): +@pytest.mark.parametrize("indices_type", ["list", "array", "series", "slice"]) +def test_safe_indexing_2d_container_axis_0(array_type, indices_type): + indices = [1, 2] + if indices_type == 'slice' and isinstance(indices[1], int): + indices[1] += 1 array = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type) + indices = _convert_container(indices, indices_type) subset = safe_indexing(array, indices, axis=0) assert_allclose_dense_sparse( subset, _convert_container([[4, 5, 6], [7, 8, 9]], array_type) @@ -258,8 +257,11 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices): @pytest.mark.parametrize("array_type", ["list", "array", "series"]) @pytest.mark.parametrize("indices_type", ["list", "array", "series"]) def test_safe_indexing_1d_container(array_type, indices_type): + indices = [1, 2] + if indices_type == 'slice' and isinstance(indices[1], int): + indices[1] += 1 array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type) - indices = _convert_container([1, 2], indices_type) + indices = _convert_container(indices, indices_type) subset = safe_indexing(array, indices, axis=0) assert_allclose_dense_sparse( subset, _convert_container([2, 3], array_type) @@ -319,9 +321,27 @@ def test_safe_indexing_2d_read_only_axis_1(array_read_only, indices_read_only, ) +@pytest.mark.parametrize("array_type", ["list", "array", "series"]) +@pytest.mark.parametrize("indices_type", ["list", "array", "series"]) +def test_safe_indexing_1d_container_mask(array_type, indices_type): + indices = [False] + [True] * 2 + [False] * 6 + array = _convert_container([1, 2, 3, 4, 5, 6, 7, 8, 9], array_type) + indices = _convert_container(indices, indices_type) + subset = safe_indexing(array, indices, axis=0) + assert_allclose_dense_sparse( + subset, _convert_container([2, 3], array_type) + ) + + @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) @pytest.mark.parametrize("indices_type", ["list", "array", "series"]) -def test_safe_indexing_2d_mask_axis_1(array_type, indices_type): +@pytest.mark.parametrize( + "axis, expected_subset", + [(0, [[4, 5, 6], [7, 8, 9]]), + (1, [[2, 3], [5, 6], [8, 9]])] +) +def test_safe_indexing_2d_mask(array_type, indices_type, axis, + expected_subset): columns_name = ['col_0', 'col_1', 'col_2'] array = _convert_container( [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name @@ -329,9 +349,9 @@ def test_safe_indexing_2d_mask_axis_1(array_type, indices_type): indices = [False, True, True] indices = _convert_container(indices, indices_type) - subset = safe_indexing(array, indices, axis=1) + subset = safe_indexing(array, indices, axis=axis) assert_allclose_dense_sparse( - subset, _convert_container([[2, 3], [5, 
6], [8, 9]], array_type) + subset, _convert_container(expected_subset, array_type) ) @@ -422,10 +442,10 @@ def test_safe_indexing_1d_array_error(X_constructor): safe_indexing(X_constructor, [0, 1], axis=1) -@pytest.mark.parametrize("indices", [["col_1", "col_2"], [True, False, True]]) -def test_safe_indexing_container_axis_0_unsupported_type(indices): +def test_safe_indexing_container_axis_0_unsupported_type(): + indices = ["col_1", "col_2"] array = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - err_msg = "'axis=0' only support integer array-like or scalar integer" + err_msg = "'axis=0' only support integer or boolean array-like," with pytest.raises(ValueError, match=err_msg): safe_indexing(array, indices, axis=0) From 698aef2c628a797e2d4b46e21e74adbe0e3a004e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Aug 2019 14:47:00 +0200 Subject: [PATCH 36/45] PEP8 --- sklearn/impute/_iterative.py | 1 - sklearn/utils/__init__.py | 1 - 2 files changed, 2 deletions(-) diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index 1fd60fa0e1768..7abffc1db389c 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -279,7 +279,6 @@ def _impute_one_feature(self, estimator.fit(X_train, y_train) # get posterior samples - indices_row_missing = np.flatnonzero(missing_row_mask) X_test = safe_indexing(X_filled[:, neighbor_feat_idx], missing_row_mask) if self.sample_posterior: diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5d2e52cc3d85f..5f384f763e2bb 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -1,7 +1,6 @@ """ The :mod:`sklearn.utils` module includes various utilities. """ -from collections.abc import Iterable from collections.abc import Sequence from contextlib import contextmanager from itertools import compress From 4f2fd8f31d82b074aa8d661df6b438e73c75088c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Aug 2019 14:48:13 +0200 Subject: [PATCH 37/45] typo --- sklearn/impute/_iterative.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/impute/_iterative.py b/sklearn/impute/_iterative.py index 7abffc1db389c..05e2f1484fccf 100644 --- a/sklearn/impute/_iterative.py +++ b/sklearn/impute/_iterative.py @@ -275,7 +275,7 @@ def _impute_one_feature(self, X_train = safe_indexing(X_filled[:, neighbor_feat_idx], ~missing_row_mask) y_train = safe_indexing(X_filled[:, feat_idx], - missing_row_mask) + ~missing_row_mask) estimator.fit(X_train, y_train) # get posterior samples From abc90d76d3d8f5fbf9eb5729eeb85aa5dd89469a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 19 Aug 2019 15:12:18 +0200 Subject: [PATCH 38/45] fix corner case --- sklearn/utils/__init__.py | 2 ++ sklearn/utils/tests/test_utils.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 5f384f763e2bb..a70d6ff7efbd4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -247,6 +247,8 @@ def _check_key_type(key, superclass): if isinstance(key, superclass): return True if isinstance(key, slice): + if key.start is None and key.stop is None: + return False return (isinstance(key.start, (superclass, type(None))) and isinstance(key.stop, (superclass, type(None)))) if isinstance(key, list): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 4e774d6df6884..9b14d8e07e530 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -203,6 +203,9 @@ def 
test_column_or_1d(): ('0', int, False), ([0, 1, 2], int, True), (['0', '1', '2'], int, False), + (slice(None, None), str, False), + (slice(None, None), bool, False), + (slice(None, None), int, False), (slice(0, 2), int, True), (np.array([0, 1, 2], dtype=np.int32), int, True), (np.array([0, 1, 2], dtype=np.int64), int, True), From b99b5a26f558cfe5012f0914cf7a9d7b326f6d1a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Aug 2019 15:57:45 +0200 Subject: [PATCH 39/45] determine the dtype instead of checking it several times --- sklearn/compose/_column_transformer.py | 7 +- sklearn/utils/__init__.py | 128 +++++++++++++------------ sklearn/utils/tests/test_utils.py | 46 ++++----- 3 files changed, 90 insertions(+), 91 deletions(-) diff --git a/sklearn/compose/_column_transformer.py b/sklearn/compose/_column_transformer.py index ab5b4f941f4c7..fb46d07d172f6 100644 --- a/sklearn/compose/_column_transformer.py +++ b/sklearn/compose/_column_transformer.py @@ -20,7 +20,7 @@ from ..utils import Bunch from ..utils import safe_indexing from ..utils import _get_column_indices -from ..utils import _check_key_type +from ..utils import _determine_key_type from ..utils.metaestimators import _BaseComposition from ..utils.validation import check_array, check_is_fitted @@ -309,7 +309,8 @@ def _validate_remainder(self, X): # Make it possible to check for reordered named columns on transform if (hasattr(X, 'columns') and - any(_check_key_type(cols, str) for cols in self._columns)): + any(_determine_key_type(cols) == 'str' + for cols in self._columns)): self._df_columns = X.columns self._n_features = X.shape[1] @@ -755,6 +756,6 @@ def _is_negative_indexing(key): def is_neg(x): return isinstance(x, numbers.Integral) and x < 0 if isinstance(key, slice): return is_neg(key.start) or is_neg(key.stop) - elif _check_key_type(key, int): + elif _determine_key_type(key) == 'int': return np.any(np.asarray(key) < 0) return False diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index a70d6ff7efbd4..043d8b5ddfd21 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -181,94 +181,94 @@ def axis0_safe_slice(X, mask, len_mask): return np.zeros(shape=(0, X.shape[1])) -def _array_indexing(array, key, axis): +def _array_indexing(array, key, key_dtype, axis): """Index an array consistently across NumPy version.""" if np_version < (1, 12) or issparse(array): # FIXME: Remove the check for NumPy when using >= 1.12 # check if we have an boolean array-likes to make the proper indexing - key_array = np.asarray(key) - if np.issubdtype(key_array.dtype, np.bool_): - key = key_array + if key_dtype == 'bool': + key = np.asarray(key) return array[key] if axis == 0 else array[:, key] -def _pandas_indexing(X, key, axis): +def _pandas_indexing(X, key, key_dtype, axis): """Index a pandas dataframe or a series.""" - # check whether we should index with loc or iloc - if _check_key_type(key, int): - by_name = False - elif _check_key_type(key, str): - by_name = True - elif _check_key_type(key, bool): - # pandas boolean masks don't work with iloc, so take loc path - by_name = True - else: - raise ValueError("No valid specification of the columns. 
Only a " - "scalar, list or slice of all integers or all " - "strings, or boolean mask is allowed") - if hasattr(key, 'shape'): # Work-around for indexing with read-only key in pandas # FIXME: solved in pandas 0.25 key = np.asarray(key) key = key if key.flags.writeable else key.copy() - indexer = 'loc' if by_name else 'iloc' + # check whether we should index with loc or iloc + indexer = 'iloc' if key_dtype == 'int' else 'loc' return (getattr(X, indexer)[:, key] if axis else getattr(X, indexer)[key]) -def _list_indexing(X, key): +def _list_indexing(X, key, key_dtype): """Index a Python list.""" if np.isscalar(key) or isinstance(key, slice): # key is a slice or a scalar return X[key] - key_set = set(key) - if (len(key_set) == 2 and - all(isinstance(k, (bool, np.bool_)) for k in key_set)): + if key_dtype == 'bool': # key is a boolean array-like return list(compress(X, key)) # key is a integer array-like of key return [X[idx] for idx in key] -def _check_key_type(key, superclass): - """Check that scalar, list or slice is of a certain type. - - This is only used in _safe_indexing_column and _get_column_indices to check - if the ``key`` (column specification) is fully integer or fully - string-like. +def _determine_key_type(key): + """Determine the data type of key. Parameters ---------- - key : scalar, list, slice, array-like - The column specification to check. - superclass : int or str - The type for which to check the `key`. + key : scalar, slice or array-like + The key from which we want to infer the data type. + + Returns + ------- + dtype : {'int', 'str', 'bool'} + Returns the data type of key. """ - if isinstance(key, superclass): - return True + err_msg = ("No valid specification of the columns. Only a scalar, list or " + "slice of all integers or all strings, or boolean mask is " + "allowed") + + dtype_to_str = {int: 'int', str: 'str', bool: 'bool', np.bool_: 'bool'} + array_dtype_to_str = {'i': 'int', 'u': 'int', 'b': 'bool', 'O': 'str', + 'U': 'str', 'S': 'str'} + + if key is None: + return None + if isinstance(key, tuple(dtype_to_str.keys())): + try: + return dtype_to_str[type(key)] + except KeyError: + raise ValueError(err_msg) if isinstance(key, slice): if key.start is None and key.stop is None: - return False - return (isinstance(key.start, (superclass, type(None))) and - isinstance(key.stop, (superclass, type(None)))) + return None + key_start_type = _determine_key_type(key.start) + key_stop_type = _determine_key_type(key.stop) + if key_start_type is not None and key_stop_type is not None: + if key_start_type != key_stop_type: + raise ValueError(err_msg) + if key_start_type is not None: + return key_start_type + return key_stop_type if isinstance(key, list): unique_key = set(key) - all_superclass = all(isinstance(x, superclass) for x in unique_key) - if superclass != int: - return all_superclass - # bool is a subclass of int, therefore we should check specifically for - # any bool - any_bool = any(isinstance(x, bool) for x in unique_key) - return all_superclass and not any_bool + set_type = {_determine_key_type(elt) for elt in unique_key} + if not set_type: + return None + if len(set_type) != 1: + raise ValueError(err_msg) + set_type, = set_type + return set_type if hasattr(key, 'dtype'): - if superclass is int: - return key.dtype.kind == 'i' - elif superclass is bool: - return key.dtype.kind == 'b' - else: - # superclass = str - return key.dtype.kind in ('O', 'U', 'S') - return False + try: + return array_dtype_to_str[key.dtype.kind] + except KeyError: + raise 
ValueError(err_msg) + raise ValueError(err_msg) def safe_indexing(X, indices, axis=0): @@ -315,11 +315,11 @@ def safe_indexing(X, indices, axis=0): " column). Got {} instead.".format(axis) ) - if axis == 0 and _check_key_type(indices, str): + indices_dtype = _determine_key_type(indices) + + if axis == 0 and indices_dtype == 'str': raise ValueError( - "'axis=0' only support integer or boolean array-like, slice with " - "integer, or scalar integer as indices. Got {} instead." - .format(indices) + "String indexing is not supported with 'axis=0'" ) if axis == 1 and X.ndim != 2: @@ -329,18 +329,18 @@ def safe_indexing(X, indices, axis=0): "Got {} instead with {} dimension(s).".format(type(X), X.ndim) ) - if axis == 1 and _check_key_type(indices, str) and not hasattr(X, 'loc'): + if axis == 1 and indices_dtype == 'str' and not hasattr(X, 'loc'): raise ValueError( "Specifying the columns using strings is only supported for " "pandas DataFrames" ) if hasattr(X, "iloc"): - return _pandas_indexing(X, indices, axis=axis) + return _pandas_indexing(X, indices, indices_dtype, axis=axis) elif hasattr(X, "shape"): - return _array_indexing(X, indices, axis=axis) + return _array_indexing(X, indices, indices_dtype, axis=axis) else: - return _list_indexing(X, indices) + return _list_indexing(X, indices, indices_dtype) def _get_column_indices(X, key): @@ -351,10 +351,12 @@ def _get_column_indices(X, key): """ n_columns = X.shape[1] + key_dtype = _determine_key_type(key) + if isinstance(key, list) and not key: # we get an empty list return [] - elif (_check_key_type(key, int) or _check_key_type(key, bool)): + elif key_dtype in ('bool', 'int'): # Convert key into positive indexes try: idx = safe_indexing(np.arange(n_columns), key) @@ -364,7 +366,7 @@ def _get_column_indices(X, key): .format(n_columns - 1, n_columns) ) from e return np.atleast_1d(idx).tolist() - elif _check_key_type(key, str): + elif key_dtype == 'str': try: all_columns = list(X.columns) except AttributeError: diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 9b14d8e07e530..272bb4cc54dba 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -14,7 +14,7 @@ assert_raises_regex, assert_warns_message, assert_no_warnings) from sklearn.utils import check_random_state -from sklearn.utils import _check_key_type +from sklearn.utils import _determine_key_type from sklearn.utils import deprecated from sklearn.utils import _get_column_indices from sklearn.utils import resample @@ -198,30 +198,26 @@ def test_column_or_1d(): @pytest.mark.parametrize( - "key, clazz, is_expected_type", - [(0, int, True), - ('0', int, False), - ([0, 1, 2], int, True), - (['0', '1', '2'], int, False), - (slice(None, None), str, False), - (slice(None, None), bool, False), - (slice(None, None), int, False), - (slice(0, 2), int, True), - (np.array([0, 1, 2], dtype=np.int32), int, True), - (np.array([0, 1, 2], dtype=np.int64), int, True), - (np.array([0, 1, 2], dtype=np.uint8), int, False), - ([True, False], bool, True), - ([True, False], int, False), - (np.array([True, False]), bool, True), - (np.array([True, False]), int, False), - ('col_0', str, True), - (['col_0', 'col_1', 'col_2'], str, True), - (slice('begin', 'end'), str, True), - (np.array(['col_0', 'col_1', 'col_2']), str, True), - (np.array(['col_0', 'col_1', 'col_2'], dtype=object), str, True)] + "key, dtype", + [(0, 'int'), + ('0', 'str'), + ([0, 1, 2], 'int'), + (['0', '1', '2'], 'str'), + (slice(None, None), None), + (slice(0, 2), 'int'), + 
(np.array([0, 1, 2], dtype=np.int32), 'int'), + (np.array([0, 1, 2], dtype=np.int64), 'int'), + (np.array([0, 1, 2], dtype=np.uint8), 'int'), + ([True, False], 'bool'), + (np.array([True, False]), 'bool'), + ('col_0', 'str'), + (['col_0', 'col_1', 'col_2'], 'str'), + (slice('begin', 'end'), 'str'), + (np.array(['col_0', 'col_1', 'col_2']), 'str'), + (np.array(['col_0', 'col_1', 'col_2'], dtype=object), 'str')] ) -def test_check_key_type(key, clazz, is_expected_type): - assert _check_key_type(key, clazz) is is_expected_type +def test_determine_key_type(key, dtype): + assert _determine_key_type(key) == dtype def _convert_container(container, constructor_name, columns_name=None): @@ -448,7 +444,7 @@ def test_safe_indexing_1d_array_error(X_constructor): def test_safe_indexing_container_axis_0_unsupported_type(): indices = ["col_1", "col_2"] array = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] - err_msg = "'axis=0' only support integer or boolean array-like," + err_msg = "String indexing is not supported with 'axis=0'" with pytest.raises(ValueError, match=err_msg): safe_indexing(array, indices, axis=0) From 3bf36b0a8af9eb9f8b5b25ea471f3828d6ef95aa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 20 Aug 2019 17:50:13 +0200 Subject: [PATCH 40/45] add for coverage --- sklearn/utils/tests/test_utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 272bb4cc54dba..7c65a923de355 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -220,6 +220,11 @@ def test_determine_key_type(key, dtype): assert _determine_key_type(key) == dtype +def test_determine_key_type_error(): + with pytest.raises(ValueError, match="No valid specification of the"): + _determine_key_type(1.0) + + def _convert_container(container, constructor_name, columns_name=None): if constructor_name == 'list': return list(container) @@ -254,7 +259,7 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices_type): @pytest.mark.parametrize("array_type", ["list", "array", "series"]) -@pytest.mark.parametrize("indices_type", ["list", "array", "series"]) +@pytest.mark.parametrize("indices_type", ["list", "array", "series", "slice"]) def test_safe_indexing_1d_container(array_type, indices_type): indices = [1, 2] if indices_type == 'slice' and isinstance(indices[1], int): From bd86ccd03da881b190f86e1b55d2f83bb3105ea4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 21 Aug 2019 10:54:16 +0200 Subject: [PATCH 41/45] iter --- sklearn/utils/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 043d8b5ddfd21..57485deb79da4 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -225,7 +225,7 @@ def _determine_key_type(key): Returns ------- - dtype : {'int', 'str', 'bool'} + dtype : {'int', 'str', 'bool', None} Returns the data type of key. """ err_msg = ("No valid specification of the columns. 
Only a scalar, list or " @@ -261,8 +261,7 @@ def _determine_key_type(key): return None if len(set_type) != 1: raise ValueError(err_msg) - set_type, = set_type - return set_type + return set_type.pop() if hasattr(key, 'dtype'): try: return array_dtype_to_str[key.dtype.kind] From bfb9fa287c38de5430fc8c7ee910e017eedd5cec Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 21 Aug 2019 11:07:30 +0200 Subject: [PATCH 42/45] itert --- sklearn/utils/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 57485deb79da4..ca857439b2968 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -256,12 +256,12 @@ def _determine_key_type(key): return key_stop_type if isinstance(key, list): unique_key = set(key) - set_type = {_determine_key_type(elt) for elt in unique_key} - if not set_type: + key_type = {_determine_key_type(elt) for elt in unique_key} + if not key_type: return None - if len(set_type) != 1: + if len(key_type) != 1: raise ValueError(err_msg) - return set_type.pop() + return key_type.pop() if hasattr(key, 'dtype'): try: return array_dtype_to_str[key.dtype.kind] From c433494cf878c19dd9ed83b40af9fd3241821d8a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 21 Aug 2019 11:09:37 +0200 Subject: [PATCH 43/45] bool type --- sklearn/utils/tests/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 7c65a923de355..363b77a44b5fc 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -201,6 +201,8 @@ def test_column_or_1d(): "key, dtype", [(0, 'int'), ('0', 'str'), + (True, 'bool'), + (np.bool_(True), 'bool'), ([0, 1, 2], 'int'), (['0', '1', '2'], 'str'), (slice(None, None), None), From 7994cc926af0572e43aa79ec2a3d85b320c1449a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 21 Aug 2019 13:47:01 +0200 Subject: [PATCH 44/45] joel comments --- sklearn/utils/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index ca857439b2968..cb5b7a2e8f9f5 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -199,8 +199,8 @@ def _pandas_indexing(X, key, key_dtype, axis): key = np.asarray(key) key = key if key.flags.writeable else key.copy() # check whether we should index with loc or iloc - indexer = 'iloc' if key_dtype == 'int' else 'loc' - return (getattr(X, indexer)[:, key] if axis else getattr(X, indexer)[key]) + indexer = X.iloc if key_dtype == 'int' else X.loc + return indexer[:, key] if axis else indexer[key] def _list_indexing(X, key, key_dtype): From 5e29cc61933e26c97d0ff3083c71d94166700887 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 21 Aug 2019 13:47:58 +0200 Subject: [PATCH 45/45] Update sklearn/utils/__init__.py Co-Authored-By: Joel Nothman --- sklearn/utils/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index cb5b7a2e8f9f5..c26fab41bfc93 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -182,7 +182,7 @@ def axis0_safe_slice(X, mask, len_mask): def _array_indexing(array, key, key_dtype, axis): - """Index an array consistently across NumPy version.""" + """Index an array or scipy.sparse consistently across NumPy version.""" if np_version < (1, 12) or issparse(array): # FIXME: Remove the check for NumPy when using >= 1.12 # 
check if we have a boolean array-like to make the proper indexing