diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst index 1753bf9b404bb..f3590aa3a0174 100644 --- a/doc/modules/classes.rst +++ b/doc/modules/classes.rst @@ -656,6 +656,7 @@ Kernels: impute.SimpleImputer impute.MissingIndicator + impute.SamplingImputer .. _kernel_approximation_ref: diff --git a/doc/modules/impute.rst b/doc/modules/impute.rst index 0fd119857177b..30470256f2588 100644 --- a/doc/modules/impute.rst +++ b/doc/modules/impute.rst @@ -72,8 +72,33 @@ string values or pandas categoricals when using the ``'most_frequent'`` or ['b' 'y']] -:class:`SimpleImputer` can be used in a Pipeline as a way to build a composite -estimator that supports imputation. See :ref:`sphx_glr_auto_examples_plot_missing_values.py`. +Unlike the strategies of the :class:`SimpleImputer`, which are all deterministic, +the :class:`SamplingImputer` class provides a non-deterministic strategy to perform +univariate feature imputation. Imputation is performed by sampling uniformly at +random from the non-missing values. Therefore, the imputed feature distribution +is asymptotically identical to the original distribution, preserving mean and +variance, for example. The :class:`SamplingImputer` class supports +sparse and categorical data. It is used the same way as the :class:`SimpleImputer`:: + + >>> import numpy as np + >>> from sklearn.impute import SamplingImputer + >>> X = np.array([[1, -1], + ... [2, -2], + ... [np.nan, np.nan], + ... [np.nan, np.nan], + ... [np.nan, np.nan]]) + >>> imp = SamplingImputer(missing_values=np.nan, random_state=0) + >>> print(imp.fit_transform(X)) # doctest: +NORMALIZE_WHITESPACE + [[ 1. -1.] + [ 2. -2.] + [ 1. -2.] + [ 2. -1.] + [ 2. -1.]] + + +:class:`SimpleImputer` and :class:`SamplingImputer` can be used in a Pipeline as +a way to build a composite estimator that supports imputation. +See :ref:`sphx_glr_auto_examples_plot_missing_values.py`. ..
_missing_indicator: diff --git a/doc/whats_new/v0.20.rst b/doc/whats_new/v0.20.rst index 6b199e9e2bac1..014eefb35234c 100644 --- a/doc/whats_new/v0.20.rst +++ b/doc/whats_new/v0.20.rst @@ -152,6 +152,10 @@ Preprocessing missing values. :issue:`8075` by :user:`Maniteja Nandana <maniteja123>` and :user:`Guillaume Lemaitre <glemaitre>`. +- Added :class:`impute.SamplingImputer`, which is a univariate strategy for + imputing missing values by sampling uniformly at random from the non-missing + values. :issue:`11368` by :user:`Jeremie du Boisberranger <jeremiedbb>`. + - :class:`linear_model.SGDClassifier`, :class:`linear_model.SGDRegressor`, :class:`linear_model.PassiveAggressiveClassifier`, :class:`linear_model.PassiveAggressiveRegressor` and diff --git a/sklearn/impute.py b/sklearn/impute.py index e98c425d1b34f..2e87f8ab63995 100644 --- a/sklearn/impute.py +++ b/sklearn/impute.py @@ -2,6 +2,7 @@ # Authors: Nicolas Tresegnie # Sergey Feldman # License: BSD 3 clause +from __future__ import division import warnings import numbers @@ -15,8 +16,10 @@ from .utils import check_array from .utils.sparsefuncs import _get_median from .utils.validation import check_is_fitted +from .utils.validation import check_random_state from .utils.validation import FLOAT_DTYPES from .utils.fixes import _object_dtype_isnan +from .utils.fixes import _uniques_counts from .utils import is_scalar_nan from .externals import six @@ -27,6 +30,7 @@ __all__ = [ 'MissingIndicator', 'SimpleImputer', + 'SamplingImputer' ] @@ -127,16 +131,16 @@ class SimpleImputer(BaseEstimator, TransformerMixin): copy : boolean, optional (default=True) If True, a copy of X will be created. If False, imputation will - be done in-place whenever possible. Note that, in the following cases, - a new copy will always be made, even if `copy=False`: - - - If X is not an array of floating values; - - If X is encoded as a CSR matrix. + be done in-place whenever possible.
class SamplingImputer(BaseEstimator, TransformerMixin):
    """Imputation transformer for completing missing values.

    Impute each feature's missing values by sampling from the empirical
    distribution.

    Read more in the :ref:`User Guide <impute>`.

    Parameters
    ----------
    missing_values : number, string, np.nan (default) or None
        The placeholder for the missing values. All occurrences of
        `missing_values` will be imputed.

    verbose : integer, optional (default=0)
        Controls the verbosity of the imputer.

    copy : boolean, optional (default=True)
        If True, a copy of X will be created. If False, imputation will
        be done in-place whenever possible. Note that if X is sparse and not
        encoded as a CSC matrix, a new copy will always be made, even
        if ``copy=False``.

    random_state : int, RandomState instance or None, optional (default=None)
        The seed of the pseudo random number generator to use when sampling
        fill values.

        - If int, random_state is the seed used by the random number generator;
        - If RandomState instance, random_state is the random number generator;
        - If None, the random number generator is the RandomState instance used
          by ``np.random``.

    Attributes
    ----------
    uniques_ : array of shape (n_features,)
        For each feature i, ``uniques_[i]`` contains all the non-missing values
        in that feature without repetitions. Set to None if the feature
        contains only missing values.

    probas_ : array of shape (n_features,)
        The probabilities associated with all the values in uniques_. For each
        feature i, ``probas_[i]`` is set to None if the feature contains only
        missing values or if there are no duplicates in the non-missing values.

    Examples
    --------
    >>> import numpy as np
    >>> from sklearn.impute import SamplingImputer
    >>> imputer = SamplingImputer(random_state=1234)
    >>> X = [[7, 2, 3], [4, np.nan, 6], [10, 5, 9]]
    >>> print(imputer.fit_transform(X))  # doctest: +NORMALIZE_WHITESPACE
    [[ 7.  2.  3.]
     [ 4.  5.  6.]
     [10.  5.  9.]]
    >>> print(imputer.transform([[np.nan, 2, 3],
    ...                          [4, np.nan, 6],
    ...                          [10, np.nan, 9]]))
    ... # doctest: +NORMALIZE_WHITESPACE
    [[10.  2.  3.]
     [ 4.  5.  6.]
     [10.  2.  9.]]

    Notes
    -----
    Columns which only contained missing values at `fit` are discarded upon
    `transform`.
    """

    def __init__(self, missing_values=np.nan,
                 verbose=0, copy=True, random_state=None):
        self.missing_values = missing_values
        self.verbose = verbose
        self.copy = copy
        self.random_state = random_state

    def _validate_input(self, X):
        """Validate X and reject dtypes that cannot be imputed.

        Sparse input is converted to CSC. Only integer, unsigned integer,
        floating point and object dtypes are accepted.
        """
        # NaN placeholders can only be found if check_array lets NaN through.
        if is_scalar_nan(self.missing_values):
            force_all_finite = "allow-nan"
        else:
            force_all_finite = True

        X = check_array(X, accept_sparse='csc', dtype=None,
                        force_all_finite=force_all_finite, copy=self.copy)

        if X.dtype.kind not in ("i", "u", "f", "O"):
            raise ValueError("SamplingImputer does not support data with dtype"
                             " {0}. Please provide either a numeric array ( "
                             "with a floating point or integer dtype) or "
                             "categorical data represented either as an array "
                             "with integer dtype or an array of string values "
                             "with an object dtype.".format(X.dtype))

        return X

    def _check_sparse_missing_values(self):
        """Raise if zeros are to be imputed in sparse data.

        Treating 0 as the missing value would require visiting every implicit
        zero, i.e. densifying the matrix.
        """
        if self.missing_values == 0:
            raise ValueError("Imputation not possible when missing_values"
                             " == 0 and input is sparse. Provide a dense "
                             "array instead.")

    @staticmethod
    def _empirical_distribution(values, counts):
        """Return ``(uniques, probas)`` for one feature.

        ``values`` are the distinct observed values and ``counts`` their
        number of occurrences. Both results are None when nothing was
        observed. ``probas`` is None when every observed value is distinct:
        sampling is then uniform and no probability array needs to be stored,
        which avoids doubling the memory usage for continuous-valued features
        that rarely contain duplicates.
        """
        if values.size == 0:
            return None, None

        n_observed = counts.sum()
        if values.size == n_observed:
            return values, None
        return values, counts / n_observed

    def fit(self, X, y=None):
        """Fit the imputer on X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape (n_samples, n_features)
            Input data, where ``n_samples`` is the number of samples and
            ``n_features`` is the number of features.

        y : ignored

        Returns
        -------
        self : SamplingImputer

        Raises
        ------
        ValueError
            If X is sparse and ``missing_values == 0``, or if the dtype of X
            is not supported.
        """
        # The generator is rebuilt on every fit, so refitting with an integer
        # seed reproduces the same sequence of imputed values.
        self._random_state_ = check_random_state(self.random_state)

        X = self._validate_input(X)

        if sparse.issparse(X):
            self._check_sparse_missing_values()
            self.uniques_, self.probas_ = \
                self._sparse_fit(X, self.missing_values)
        else:
            self.uniques_, self.probas_ = \
                self._dense_fit(X, self.missing_values)

        return self

    def _sparse_fit(self, X, missing_values):
        """Fit the transformer on sparse (CSC) data."""
        mask_data = _get_mask(X.data, missing_values)
        # Entries that are not stored explicitly are observed zeros.
        n_implicit_zeros = X.shape[0] - np.diff(X.indptr)

        n_features = X.shape[1]
        uniques = np.empty(n_features, dtype=object)
        probas = np.empty(n_features, dtype=object)

        for i in range(n_features):
            start, stop = X.indptr[i], X.indptr[i + 1]
            column = X.data[start:stop]
            column = column[~mask_data[start:stop]]

            if column.size > 0:
                values, counts = _uniques_counts(column)
            else:
                # Mirror _dense_fit: never hand an empty column to the
                # counting helper.
                values = column
                counts = np.zeros(0, dtype=np.intp)

            # count implicit zeros
            if n_implicit_zeros[i] > 0:
                is_zero = values == 0
                if is_zero.any():
                    counts[is_zero] += n_implicit_zeros[i]
                else:
                    values = np.append(values, 0)
                    counts = np.append(counts, n_implicit_zeros[i])

            uniques[i], probas[i] = \
                self._empirical_distribution(values, counts)

        return uniques, probas

    def _dense_fit(self, X, missing_values):
        """Fit the transformer on dense data."""
        mask = _get_mask(X, missing_values)

        n_features = X.shape[1]
        uniques = np.empty(n_features, dtype=object)
        probas = np.empty(n_features, dtype=object)

        for i in range(n_features):
            column = X[~mask[:, i], i]
            if column.size > 0:
                values, counts = _uniques_counts(column)
                uniques[i], probas[i] = \
                    self._empirical_distribution(values, counts)
            else:
                uniques[i] = None
                probas[i] = None

        return uniques, probas

    def transform(self, X):
        """Impute all missing values in X.

        Parameters
        ----------
        X : {array-like, sparse matrix}, shape = [n_samples, n_features]
            The input data to complete.

        Returns
        -------
        X : {ndarray, sparse matrix}
            The input data with missing values replaced by values sampled
            from the ones observed at ``fit``. Features that had no observed
            value at ``fit`` are removed.

        Raises
        ------
        ValueError
            If the number of features differs from the one seen at ``fit``,
            or if X is sparse and ``missing_values == 0``.
        """
        check_is_fitted(self, ("uniques_", "probas_"))

        X = self._validate_input(X)

        uniques = self.uniques_
        probas = self.probas_

        if X.shape[1] != uniques.shape[0]:
            raise ValueError("X has %d features per sample, expected %d"
                             % (X.shape[1], uniques.shape[0]))

        # Delete the invalid columns
        valid_mask = np.asarray([x is not None for x in uniques])

        if valid_mask.all():
            # Nothing to drop: index with Ellipsis to keep everything.
            valid_mask = Ellipsis
        else:
            missing = np.flatnonzero(~valid_mask)
            if self.verbose:
                warnings.warn("Deleting features without "
                              "observed values: %s" % missing)
            X = X[:, np.flatnonzero(valid_mask)]

        valid_uniques = uniques[valid_mask]
        valid_probas = probas[valid_mask]

        # Do actual imputation
        if sparse.issparse(X):
            self._check_sparse_missing_values()

            mask_data = _get_mask(X.data, self.missing_values)
            for i in range(X.shape[1]):
                # Basic slicing returns a view, so writing into ``column``
                # updates ``X.data`` in place.
                column = X.data[X.indptr[i]:X.indptr[i + 1]]
                mask_column = mask_data[X.indptr[i]:X.indptr[i + 1]]
                n_missing = mask_column.sum()
                values = self._random_state_.choice(valid_uniques[i],
                                                    n_missing,
                                                    p=valid_probas[i])
                column[mask_column] = values

            # In case some missing values are imputed with 0. This must run
            # after the loop: it rewrites ``X.data`` and ``X.indptr``, which
            # would invalidate ``mask_data``.
            X.eliminate_zeros()
        else:
            mask = _get_mask(X, self.missing_values)
            n_missing = np.sum(mask, axis=0)
            for i in range(n_missing.shape[0]):
                values = self._random_state_.choice(valid_uniques[i],
                                                    n_missing[i],
                                                    p=valid_probas[i])
                X[mask[:, i], i] = values

        return X
@pytest.mark.parametrize("imputer_constructor, params",
                         [(SimpleImputer, {'strategy': "mean"}),
                          (SimpleImputer, {'strategy': "median"}),
                          (SimpleImputer, {'strategy': "most_frequent"}),
                          (SimpleImputer, {'strategy': "constant"}),
                          (SamplingImputer, {})])
def test_imputation_shape(imputer_constructor, params):
    # Verify the shapes of the imputed matrix for the SamplingImputer and for
    # the SimpleImputer with the different strategies
    X = np.random.randn(10, 2)
    X[::2] = np.nan

    imputer = imputer_constructor(**params)

    X_imputed = imputer.fit_transform(sparse.csr_matrix(X))
    assert X_imputed.shape == (10, 2)
    X_imputed = imputer.fit_transform(X)
    assert X_imputed.shape == (10, 2)


@pytest.mark.parametrize("imputer_constructor, params",
                         [(SimpleImputer, {'strategy': "mean"}),
                          (SimpleImputer, {'strategy': "median"}),
                          (SimpleImputer, {'strategy': "most_frequent"}),
                          (SamplingImputer, {})])
def test_imputation_deletion_warning(imputer_constructor, params):
    # verify that a warning is raised when a feature without any observed
    # value is deleted by SimpleImputer or SamplingImputer
    X = np.ones((3, 5))
    X[:, 0] = np.nan

    imputer = imputer_constructor(verbose=True, **params)

    with pytest.warns(UserWarning, match="Deleting"):
        imputer.fit_transform(X)


@pytest.mark.parametrize("imputer_constructor, params",
                         [(SimpleImputer, {'strategy': "mean"}),
                          (SimpleImputer, {'strategy': "median"}),
                          (SimpleImputer, {'strategy': "most_frequent"}),
                          (SimpleImputer, {'strategy': "constant"}),
                          (SamplingImputer, {})])
def test_imputation_error_sparse_0(imputer_constructor, params):
    # check that an error is raised when missing_values = 0 and input is
    # sparse when using SimpleImputer or SamplingImputer
    X = np.ones((3, 5))
    X[0] = 0
    X = sparse.csc_matrix(X)

    imputer = imputer_constructor(missing_values=0, **params)

    with pytest.raises(ValueError, match="Provide a dense array"):
        imputer.fit(X)


# Shared fixture for test_imputation_copy; sparse_random_matrix returns a
# sparse matrix which the cases below treat as CSR.
X_orig = sparse_random_matrix(5, 5, density=0.75, random_state=0)


@pytest.mark.parametrize(
    "X_in, imputer_constructor, params, expect_copy",
    [(X_orig.toarray(), SimpleImputer,
      {'missing_values': 0, 'copy': True}, True),
     (X_orig.toarray(), SamplingImputer,
      {'missing_values': 0, 'copy': True}, True),
     (X_orig, SimpleImputer,
      {'missing_values': X_orig.data[0], 'copy': True}, True),
     (X_orig, SamplingImputer,
      {'missing_values': X_orig.data[0], 'copy': True}, True),
     (X_orig.toarray(), SimpleImputer,
      {'missing_values': 0, 'copy': False}, False),
     (X_orig.toarray(), SamplingImputer,
      {'missing_values': 0, 'copy': False}, False),
     (X_orig.tocsc(), SimpleImputer,
      {'missing_values': X_orig.data[0], 'copy': False}, False),
     (X_orig.tocsc(), SamplingImputer,
      {'missing_values': X_orig.data[0], 'copy': False}, False),
     (X_orig, SimpleImputer,
      {'missing_values': X_orig.data[0], 'copy': False}, True),
     (X_orig, SamplingImputer,
      {'missing_values': X_orig.data[0], 'copy': False}, True)])
def test_imputation_copy(X_in, imputer_constructor, params, expect_copy):
    # Test imputation with copy:
    # copy=True, dense => copy
    # copy=True, sparse csr => copy
    # copy=False, dense => no copy
    # copy=False, sparse csc => no copy
    # copy=False, sparse csr => copy
    # Work on a private copy so that the parametrized inputs, which are
    # shared between cases, are never modified in place.
    X = X_in.copy()

    imputer = imputer_constructor(**params)
    Xt = imputer.fit(X).transform(X)

    # "No copy" means the very same container is handed back.
    if expect_copy:
        assert Xt is not X
    else:
        assert Xt is X


@pytest.mark.filterwarnings('ignore: in the future, full')
@pytest.mark.parametrize("X_value, dtype, missing_value",
                         [(1, int, -1),
                          (1, None, -1),
                          ("a", object, "NaN"),
                          ("a", object, np.nan),
                          ("a", object, None),
                          (1.0, float, 0),
                          (1.0, None, np.nan)])
def test_sampling_deterministic(X_value, dtype, missing_value):
    # test SamplingImputer on a known output: when every feature has a single
    # observed value, sampling can only return that value
    X = np.full((10, 10), X_value, dtype=dtype)
    X[:, 0] = missing_value
    X[::2, ::2] = missing_value

    # the first feature has no observed value and is dropped
    X_true = np.full((10, 9), X_value, dtype=dtype)

    imputer = SamplingImputer(missing_values=missing_value)

    X_trans = imputer.fit_transform(X)

    assert_array_equal(X_true, X_trans)


@pytest.mark.filterwarnings('ignore: in the future, full')
@pytest.mark.parametrize("dtype", [str, np.dtype('U'), np.dtype('S')])
def test_sampling_error_invalid_type(dtype):
    # Assert that errors are raised on invalid types, both at fit and at
    # transform time
    X = np.array([
        [np.nan, np.nan, "a", "f"],
        [np.nan, "c", np.nan, "d"],
        [np.nan, "b", "d", np.nan],
        [np.nan, "c", "d", "h"],
    ], dtype=object)

    imputer = SamplingImputer()

    err_msg = "SamplingImputer does not support data"
    with pytest.raises(ValueError, match=err_msg):
        imputer.fit(X.astype(dtype=dtype))

    imputer.fit(X)

    with pytest.raises(ValueError, match=err_msg):
        imputer.transform(X.astype(dtype=dtype))


@pytest.mark.filterwarnings('ignore: in the future, full')
def test_sampling_preserved_statistics():
    # check that: - filled values are drawn only within non-missing values
    #             - different random_states give different imputations
    #             - values are drawn uniformly at random
    # The data is seeded so that the statistical checks are reproducible.
    rng = np.random.RandomState(0)
    X = rng.rand(20).reshape(-1, 1)
    X[::2] = np.nan

    uniques = np.unique(X)
    uniques = uniques[~_get_mask(uniques, np.nan)]

    imputer = SamplingImputer()
    Xts = []
    for seed in range(100):
        Xt = imputer.set_params(random_state=seed).fit_transform(X)
        assert_array_equal(uniques, np.unique(Xt))
        Xts.append(Xt)

    identical = [np.allclose(Xts[i], Xts[i - 1]) for i in range(100)]
    assert not all(identical)

    # Only 1000 values are sampled overall, so the pooled mean and standard
    # deviation are expected to sit about 1% away from the observed ones;
    # a 5% tolerance keeps the check meaningful without being brittle.
    X_pooled = np.concatenate(Xts)
    assert np.mean(X_pooled) == pytest.approx(np.nanmean(X), rel=5e-2)
    assert np.std(X_pooled) == pytest.approx(np.nanstd(X), rel=5e-2)


def test_sampling_transform_vs_fit_transform():
    # fit(X).transform(X) and fit_transform(X) must agree when the imputer is
    # seeded with an integer, since every fit restarts the random sequence
    X = np.random.random_sample((10, 10))
    X[::2, 1::2] = np.nan

    imputer = SamplingImputer(random_state=0)

    Xt1 = imputer.fit(X).transform(X)
    Xt2 = imputer.fit_transform(X)

    assert_array_almost_equal(Xt1, Xt2)
# backport for missing return_counts parameter in numpy.unique for numpy
# versions < 1.9

if np_version < (1, 9):
    def _uniques_counts(array):
        """Return the sorted unique values of ``array`` and their counts.

        Equivalent to ``np.unique(array, return_counts=True)`` for NumPy
        releases that do not provide the ``return_counts`` keyword.
        """
        unique, idx = np.unique(array, return_inverse=True)
        if idx.size == 0:
            # NOTE(review): np.bincount is believed to reject empty input on
            # some NumPy releases covered by this branch; answer directly.
            return unique, np.zeros(0, dtype=np.intp)
        counts = np.bincount(idx)
        return unique, counts
else:
    def _uniques_counts(array):
        """Return the sorted unique values of ``array`` and their counts."""
        return np.unique(array, return_counts=True)