diff --git a/doc/conftest.py b/doc/conftest.py
index eacd469f2e52f..0e8f97e589264 100644
--- a/doc/conftest.py
+++ b/doc/conftest.py
@@ -50,6 +50,13 @@ def setup_compose():
         raise SkipTest("Skipping compose.rst, pandas not installed")
 
 
+def setup_metrics():
+    try:
+        import pandas  # noqa
+    except ImportError:
+        raise SkipTest("Skipping metrics.rst, pandas not installed")
+
+
 def setup_impute():
     try:
         import pandas  # noqa
@@ -82,6 +89,8 @@ def pytest_runtest_setup(item):
         setup_working_with_text_data()
     elif fname.endswith('modules/compose.rst') or is_index:
         setup_compose()
+    elif fname.endswith('modules/metrics.rst') or is_index:
+        setup_metrics()
     elif IS_PYPY and fname.endswith('modules/feature_extraction.rst'):
         raise SkipTest('FeatureHasher is not compatible with PyPy')
     elif fname.endswith('modules/impute.rst'):
diff --git a/doc/modules/classes.rst b/doc/modules/classes.rst
index 35fa24ac9a846..a0083cdd4d682 100644
--- a/doc/modules/classes.rst
+++ b/doc/modules/classes.rst
@@ -1067,6 +1067,7 @@ See the :ref:`metrics` section of the user guide for further details.
    metrics.pairwise.cosine_distances
    metrics.pairwise.distance_metrics
    metrics.pairwise.euclidean_distances
+   metrics.pairwise.gower_distances
    metrics.pairwise.haversine_distances
    metrics.pairwise.kernel_metrics
    metrics.pairwise.laplacian_kernel
diff --git a/doc/modules/metrics.rst b/doc/modules/metrics.rst
index a5ef07e196ef6..c0bc4f4a5f672 100644
--- a/doc/modules/metrics.rst
+++ b/doc/modules/metrics.rst
@@ -93,6 +93,61 @@ is equivalent to :func:`linear_kernel`, only slower.)
     Information Retrieval. Cambridge University Press.
     https://nlp.stanford.edu/IR-book/html/htmledition/the-vector-space-model-for-scoring-1.html
 
+.. _gower_distances:
+
+Gower distances
+---------------
+
+The function :func:`~sklearn.metrics.pairwise.gower_distances` computes the
+pairwise distances between the observations in X and Y, which may contain a
+mix of numerical, boolean, and categorical attributes, using a distance-based
+variant of the Gower similarity coefficient:
+
+.. math::
+
+    g(\mathbf{x}, \mathbf{y}) = \frac{\sum_i s(x_i, y_i)}{|\{i \mid x_i \neq \text{missing} \land y_i \neq \text{missing}\}|}
+
+where :math:`\mathbf{x}` and :math:`\mathbf{y}` are the two observations of
+shape (n_features,) being compared, and :math:`s(x_i, y_i)` is the
+per-feature distance, defined as:
+
+- :math:`s(x_i, y_i) := 0`, if either :math:`x_i` or :math:`y_i` is missing;
+- :math:`s(x_i, y_i) := \text{int}(x_i \neq y_i)`, if :math:`i` is a boolean
+  or categorical attribute;
+- :math:`s(x_i, y_i) := |x_i - y_i|`, if :math:`i` is a numerical attribute,
+  scaled to [0, 1] by default.
+
+The Gower formula thus combines the Manhattan (L1) distance for numeric
+features with the Hamming distance for categorical features to obtain a
+general coefficient for mixed categorical and numeric data.
+
+The :func:`gower_distances` function expects the user to specify the
+categorical features, otherwise it assumes all features are numerical. If
+the data is a `pandas.DataFrame`, you can use
+:func:`~sklearn.compose.make_column_selector` to select features::
+
+    >>> import pandas as pd
+    >>> from sklearn.compose import make_column_selector as selector
+    >>> from sklearn.metrics.pairwise import gower_distances
+    >>> X = pd.DataFrame(
+    ...     {'city': ['London', 'London', 'Paris', 'Sallisaw'],
+    ...      'expert_rating': [5, 3, 4, 5],
+    ...      'user_rating': [4, 5, 4, 3]})
+    >>> gower_distances(X, categorical_features=selector(dtype_include=object))
+    array([[0.        , 0.5       , 0.5       , 0.5       ],
+           [0.5       , 0.        , 0.6666..., 1.        ],
+           [0.5       , 0.6666..., 0.        , 0.6666...],
+           [0.5       , 1.        , 0.6666..., 0.        ]])
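+
+Missing values do not contribute to the distance: a feature that is missing
+in either observation is dropped from both the numerator and the denominator
+of the formula above. A minimal sketch, using ``scale=False`` so that the
+numerical values are used as given::
+
+    >>> import numpy as np
+    >>> gower_distances(np.array([[0., 1.], [np.nan, 0.]]), scale=False)
+    array([[0., 1.],
+           [1., 0.]])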
+
+.. topic:: References:
+
+   * Gower, J.C., 1971, A General Coefficient of Similarity and Some of Its
+     Properties, Biometrics, Vol. 27, No. 4 (Dec., 1971), pp. 857-871.
+     http://members.cbio.mines-paristech.fr/~jvert/svn/bibli/local/Gower1971general.pdf
+
 .. _linear_kernel:
 
 Linear kernel
@@ -165,14 +220,14 @@ the kernel is known as the Gaussian kernel of variance :math:`\sigma^2`.
 Laplacian kernel
 ----------------
 
-The function :func:`laplacian_kernel` is a variant on the radial basis 
+The function :func:`laplacian_kernel` is a variant on the radial basis
 function kernel defined as:
 
 .. math::
 
     k(x, y) = \exp( -\gamma \| x-y \|_1)
 
-where ``x`` and ``y`` are the input vectors and :math:`\|x-y\|_1` is the 
+where ``x`` and ``y`` are the input vectors and :math:`\|x-y\|_1` is the
 Manhattan distance between the input vectors.
 
 It has proven useful in ML applied to noiseless data.
@@ -229,4 +284,3 @@ The chi squared kernel is most commonly used on histograms (bags) of visual
     word categories: A comprehensive study
     International Journal of Computer Vision 2007
     https://research.microsoft.com/en-us/um/people/manik/projects/trade-off/papers/ZhangIJCV06.pdf
-
diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py
index 2424c84394e2b..3a42be1d88e95 100644
--- a/sklearn/metrics/pairwise.py
+++ b/sklearn/metrics/pairwise.py
@@ -21,10 +21,13 @@
 from ..utils.validation import _num_samples
 from ..utils.validation import check_non_negative
+from ..utils.validation import check_consistent_length
 from ..utils import check_array
 from ..utils import gen_even_slices
 from ..utils import gen_batches, get_chunk_n_rows
 from ..utils import is_scalar_nan
+from ..utils import _safe_indexing
+from ..utils import _get_column_indices
 from ..utils.extmath import row_norms, safe_sparse_dot
 from ..preprocessing import normalize
 from ..utils._mask import _get_mask
@@ -32,6 +35,8 @@
 from ._pairwise_fast import _chi2_kernel_fast, _sparse_manhattan
 from ..exceptions import DataConversionWarning
+from ..utils.fixes import _object_dtype_isnan
+from ..preprocessing import MinMaxScaler
 
 
 # Utility Functions
@@ -552,7 +557,7 @@ def pairwise_distances_argmin_min(X, Y, *, axis=1, metric="euclidean",
         Valid values for metric are:
 
         - from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
-          'manhattan']
+          'manhattan', 'gower']
 
         - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
           'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
@@ -641,7 +646,7 @@ def pairwise_distances_argmin(X, Y, *, axis=1, metric="euclidean",
         Valid values for metric are:
 
         - from scikit-learn: ['cityblock', 'cosine', 'euclidean', 'l1', 'l2',
-          'manhattan']
+          'manhattan', 'gower']
 
         - from scipy.spatial.distance: ['braycurtis', 'canberra', 'chebyshev',
           'correlation', 'dice', 'hamming', 'jaccard', 'kulsinski',
@@ -839,6 +844,173 @@ def cosine_distances(X, Y=None):
     return S
 
 
+def _split_categorical_numerical(X, categorical_features):
+    # The following is done before check_pairwise_arrays to avoid converting
+    # numerical data to object dtype. First we split the data into categorical
+    # and numerical parts, then we run check_array on each.
+
+    if X is None:
+        return None, None
+
+    # TODO: this should be more like check_array(..., accept_pandas=True)
+    if not hasattr(X, "shape"):
+        X = check_array(X, dtype=np.object, force_all_finite=False)
+
+    cols = categorical_features
+    if cols is None:
+        cols = []
+
+    col_idx = _get_column_indices(X, cols)
+    X_cat = _safe_indexing(X, col_idx, axis=1)
+    X_num = _safe_indexing(X, col_idx, axis=1, complement=True)
+
+    return X_cat, X_num
+
+
+def gower_distances(X, Y=None, categorical_features=None, scale=True,
+                    min_values=None, scale_factor=None):
+    """Compute the distances between the observations in X and Y, which may
+    contain mixed types of data, using an implementation of the Gower
+    formula.
+
+    Parameters
+    ----------
+    X : {array-like, pandas.DataFrame} of shape (n_samples, n_features)
+        Input data.
+
+    Y : {array-like, pandas.DataFrame} of shape (n_samples, n_features), \
+            default=None
+        Input data. If `None`, the distances between the rows of `X` are
+        computed.
+
+    categorical_features : array-like of str, array-like of int, \
+            array-like of bool, slice or callable, default=None
+        Indexes the data on its second axis. Integers are interpreted as
+        positional columns, while strings can reference DataFrame columns
+        by name.
+        A callable is passed the input data `X` and can return any of the
+        above. To select multiple columns by name or dtype, you can use
+        :obj:`~sklearn.compose.make_column_selector`.
+
+        By default (`None`) all features are treated as numerical.
+
+    scale : bool, default=True
+        Indicates if the numerical columns should be scaled to [0, 1].
+        If ``False``, the numerical columns are assumed to be already scaled.
+        The scaling factors, i.e. ``min_values`` and ``scale_factor``, are
+        taken from ``X``. If ``X`` and ``Y`` are both provided, ``min_values``
+        and ``scale_factor`` have to be provided as well.
+
+    min_values : ndarray of shape (n_numerical_features,), default=None
+        Per feature adjustment for the minimum, equivalent to the ``min_``
+        attribute of a :class:`~sklearn.preprocessing.MinMaxScaler` fitted on
+        the numerical features, i.e. ``-X.min(axis=0) * scale_factor`` for a
+        ``(0, 1)`` feature range.
+        If provided, ``scale_factor`` should be provided as well.
+        Only relevant if ``scale=True``.
+
+    scale_factor : ndarray of shape (n_numerical_features,), default=None
+        Per feature relative scaling of the data, equivalent to the ``scale_``
+        attribute of a :class:`~sklearn.preprocessing.MinMaxScaler` fitted on
+        the numerical features, i.e. ``1 / (X.max(axis=0) - X.min(axis=0))``
+        for a ``(0, 1)`` feature range.
+        If provided, ``min_values`` should be provided as well.
+        Only relevant if ``scale=True``.
+
+    Returns
+    -------
+    distances : ndarray of shape (n_samples_X, n_samples_Y)
+        The pairwise Gower distances between the rows of `X` and `Y`.
+
+    References
+    ----------
+    Gower, J.C., 1971, A General Coefficient of Similarity and Some of Its
+    Properties, Biometrics, Vol. 27, No. 4, pp. 857-871.
+
+    Notes
+    -----
+    Categorical ordinal attributes should be treated as numeric for the
+    purpose of Gower similarity.
+
+    The current implementation does not support sparse matrices.
+
+    This implementation returns ``1 - S`` rather than Gower's original
+    similarity measure ``S``, so that the result is a distance.
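+
+    Examples
+    --------
+    A minimal sketch of the default behaviour, where all features are
+    treated as numerical and scaled to [0, 1]:
+
+    >>> import numpy as np
+    >>> from sklearn.metrics.pairwise import gower_distances
+    >>> X = np.array([[0, 0], [1, 1]])
+    >>> gower_distances(X)
+    array([[0., 1.],
+           [1., 0.]])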
+ """ + def _nanmanhatan(x, y): + return np.nansum(np.abs(x - y)) + + def _non_nans(x, y): + return len(x) - np.sum(_object_dtype_isnan(x) | _object_dtype_isnan(y)) + + def _nanhamming(x, y): + return np.sum(x != y) - np.sum( + _object_dtype_isnan(x) | _object_dtype_isnan(y)) + + if issparse(X) or issparse(Y): + raise TypeError("Gower distance does not support sparse matrices") + + if X is None or len(X) == 0: + raise ValueError("X can not be None or empty") + + if scale: + if (scale_factor is None) != (min_values is None): + raise ValueError("min_value and scale_factor should be provided " + "together.") + + # scale_factor and min_values are either both None or not at this point + if X is not Y and Y is not None and scale_factor is None and scale: + raise ValueError("`scaling_factor` and `min_values` must be provided " + "when `Y` is provided and `scale=True`") + + if callable(categorical_features): + cols = categorical_features(X) + else: + cols = categorical_features + X_cat, X_num = _split_categorical_numerical(X, categorical_features=cols) + Y_cat, Y_num = _split_categorical_numerical(Y, categorical_features=cols) + + if min_values is not None: + min_values = np.asarray(min_values) + scale_factor = np.asarray(scale_factor) + check_consistent_length(min_values, scale_factor, + np.ndarray(shape=(X_num.shape[1], 0))) + + if X_num.shape[1]: + X_num, Y_num = check_pairwise_arrays(X_num, Y_num, precomputed=False, + dtype=float, + force_all_finite=False) + if scale: + scale_data = X_num if Y_num is X_num else np.vstack((X_num, Y_num)) + if scale_factor is None: + trs = MinMaxScaler().fit(scale_data) + else: + trs = MinMaxScaler() + trs.scale_ = scale_factor + trs.min_ = min_values + X_num = trs.transform(X_num) + Y_num = trs.transform(Y_num) + + nan_manhatan = distance.cdist(X_num, Y_num, _nanmanhatan) + # nan_manhatan = np.nansum(np.abs(X_num - Y_num)) + valid_num = distance.cdist(X_num, Y_num, _non_nans) + else: + nan_manhatan = valid_num = None + + if X_cat.shape[1]: + X_cat, Y_cat = check_pairwise_arrays(X_cat, Y_cat, precomputed=False, + dtype=np.object, + force_all_finite=False) + nan_hamming = distance.cdist(X_cat, Y_cat, _nanhamming) + valid_cat = distance.cdist(X_cat, Y_cat, _non_nans) + else: + nan_hamming = valid_cat = None + + # based on whether there are categorical and/or numerical data present, + # we compute the distance metric + # Division by zero and nans warnings are ignored since they are expected + with np.errstate(divide='ignore', invalid='ignore'): + if valid_num is not None and valid_cat is not None: + D = (nan_manhatan + nan_hamming) / (valid_num + valid_cat) + elif valid_num is not None: + D = nan_manhatan / valid_num + else: + D = nan_hamming / valid_cat + return D + + # Paired distances def paired_euclidean_distances(X, Y): """ @@ -915,7 +1087,7 @@ def paired_cosine_distances(X, Y): 'l2': paired_euclidean_distances, 'l1': paired_manhattan_distances, 'manhattan': paired_manhattan_distances, - 'cityblock': paired_manhattan_distances} + 'cityblock': paired_manhattan_distances, } @_deprecate_positional_args @@ -1309,6 +1481,7 @@ def chi2_kernel(X, Y=None, gamma=1.): 'l2': euclidean_distances, 'l1': manhattan_distances, 'manhattan': manhattan_distances, + 'gower': gower_distances, 'precomputed': None, # HACK: precomputed is always allowed, never called 'nan_euclidean': nan_euclidean_distances, } @@ -1333,6 +1506,7 @@ def distance_metrics(): 'l1' metrics.pairwise.manhattan_distances 'l2' metrics.pairwise.euclidean_distances 'manhattan' 
+    'gower'           metrics.pairwise.gower_distances
     'nan_euclidean'   metrics.pairwise.nan_euclidean_distances
     =============== ========================================
 
@@ -1411,7 +1585,7 @@ def _pairwise_callable(X, Y, metric, force_all_finite=True, **kwds):
                   'mahalanobis', 'matching', 'minkowski',
                   'rogerstanimoto', 'russellrao', 'seuclidean',
                   'sokalmichener', 'sokalsneath', 'sqeuclidean',
                   'yule', "wminkowski",
-                  'nan_euclidean', 'haversine']
+                  'nan_euclidean', 'haversine', 'gower']
 
 _NAN_METRICS = ['nan_euclidean']
 
@@ -1440,6 +1614,40 @@ def _check_chunk_size(reduced, chunk_size):
 def _precompute_metric_params(X, Y, metric=None, **kwds):
     """Precompute data-derived metric parameters if not provided
     """
+    if metric == 'gower':
+        categorical_features = kwds.get('categorical_features', None)
+
+        if callable(categorical_features):
+            cols = categorical_features(X)
+        else:
+            cols = categorical_features
+        _, X_num = _split_categorical_numerical(X, cols)
+
+        scale = kwds.get('scale', True)
+        if not scale:
+            return {'min_values': None, 'scale_factor': None, 'scale': False,
+                    'categorical_features': cols}
+
+        scale_factor = kwds.get('scale_factor', None)
+        min_values = kwds.get('min_values', None)
+        if (scale_factor is None) != (min_values is None):
+            raise ValueError("min_values and scale_factor should be provided "
+                             "together.")
+
+        if min_values is None and (X is Y or Y is None):
+            trs = MinMaxScaler().fit(X_num)
+            min_values = trs.min_
+            scale_factor = trs.scale_
+        elif min_values is None:
+            raise ValueError("`scale_factor` and `min_values` must be "
+                             "provided when `Y` is provided and "
+                             "`scale=True`")
+
+        return {'min_values': min_values,
+                'scale_factor': scale_factor,
+                'scale': True,
+                'categorical_features': cols}
+
     if metric == "seuclidean" and 'V' not in kwds:
         if X is Y:
             V = np.var(X, axis=0, ddof=1)
@@ -1744,6 +1952,17 @@ def pairwise_distances(X, Y=None, metric="euclidean", *, n_jobs=None,
             check_non_negative(X, whom=whom)
         return X
     elif metric in PAIRWISE_DISTANCE_FUNCTIONS:
+        if metric == 'gower':
+            params = _precompute_metric_params(X, Y, metric=metric, **kwds)
+            kwds.update(**params)
+
         func = PAIRWISE_DISTANCE_FUNCTIONS[metric]
     elif callable(metric):
         func = partial(_pairwise_callable, metric=metric,
diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py
index d7a96de12c9e3..b26572300c960 100644
--- a/sklearn/metrics/tests/test_pairwise.py
+++ b/sklearn/metrics/tests/test_pairwise.py
@@ -29,6 +29,7 @@
 from sklearn.metrics.pairwise import sigmoid_kernel
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.metrics.pairwise import cosine_distances
+from sklearn.metrics.pairwise import gower_distances
 from sklearn.metrics.pairwise import pairwise_distances
 from sklearn.metrics.pairwise import pairwise_distances_chunked
 from sklearn.metrics.pairwise import pairwise_distances_argmin_min
@@ -44,8 +45,11 @@
 from sklearn.metrics.pairwise import paired_euclidean_distances
 from sklearn.metrics.pairwise import paired_manhattan_distances
 from sklearn.metrics.pairwise import _euclidean_distances_upcast
-from sklearn.preprocessing import normalize
+from sklearn.preprocessing import normalize, minmax_scale
+from sklearn.preprocessing import OrdinalEncoder, MinMaxScaler
 from sklearn.exceptions import DataConversionWarning
+from sklearn.compose import make_column_selector
+from sklearn.utils.validation import check_random_state
 
 
 def test_pairwise_distances():
@@ -615,44 +619,37 @@ def test_pairwise_distances_chunked():
         next(gen)
 
 
-@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix],
-                         ids=["dense", "sparse"])
-@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
-                         ids=["dense", "sparse"])
-def test_euclidean_distances_known_result(x_array_constr, y_array_constr):
-    # Check the pairwise Euclidean distances computation on known result
-    X = x_array_constr([[0]])
-    Y = y_array_constr([[1], [2]])
+def test_euclidean_distances():
+    # Check the pairwise Euclidean distances computation
+    X = [[0]]
+    Y = [[1], [2]]
     D = euclidean_distances(X, Y)
-    assert_allclose(D, [[1., 2.]])
+    assert_array_almost_equal(D, [[1., 2.]])
 
+    X = csr_matrix(X)
+    Y = csr_matrix(Y)
+    D = euclidean_distances(X, Y)
+    assert_array_almost_equal(D, [[1., 2.]])
 
-@pytest.mark.parametrize("dtype", [np.float32, np.float64])
-@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
-                         ids=["dense", "sparse"])
-def test_euclidean_distances_with_norms(dtype, y_array_constr):
-    # check that we still get the right answers with {X,Y}_norm_squared
-    # and that we get a wrong answer with wrong {X,Y}_norm_squared
     rng = np.random.RandomState(0)
-    X = rng.random_sample((10, 10)).astype(dtype, copy=False)
-    Y = rng.random_sample((20, 10)).astype(dtype, copy=False)
-
-    # norms will only be used if their dtype is float64
-    X_norm_sq = (X.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
-    Y_norm_sq = (Y.astype(np.float64) ** 2).sum(axis=1).reshape(1, -1)
-
-    Y = y_array_constr(Y)
+    X = rng.random_sample((10, 4))
+    Y = rng.random_sample((20, 4))
+    X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1)
+    Y_norm_sq = (Y ** 2).sum(axis=1).reshape(1, -1)
 
+    # check that we still get the right answers with {X,Y}_norm_squared
     D1 = euclidean_distances(X, Y)
     D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq)
     D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq)
     D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq,
                              Y_norm_squared=Y_norm_sq)
-    assert_allclose(D2, D1)
-    assert_allclose(D3, D1)
-    assert_allclose(D4, D1)
+    assert_array_almost_equal(D2, D1)
+    assert_array_almost_equal(D3, D1)
+    assert_array_almost_equal(D4, D1)
 
     # check we get the wrong answer with wrong {X,Y}_norm_squared
+    X_norm_sq *= 0.5
+    Y_norm_sq *= 0.5
     wrong_D = euclidean_distances(X, Y,
                                   X_norm_squared=np.zeros_like(X_norm_sq),
                                   Y_norm_squared=np.zeros_like(Y_norm_sq))
@@ -665,7 +662,7 @@ def test_euclidean_distances_with_norms(dtype, y_array_constr):
                          ids=["dense", "sparse"])
 @pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix],
                          ids=["dense", "sparse"])
-def test_euclidean_distances(dtype, x_array_constr, y_array_constr):
+def test_euclidean_distances_cdist(dtype, x_array_constr, y_array_constr):
     # check that euclidean distances gives same result as scipy cdist
     # when X and Y != X are provided
     rng = np.random.RandomState(0)
@@ -931,6 +928,187 @@ def test_cosine_distances():
     assert np.all(D <= 2.)
 
 
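+def test_gower_distances_range_and_symmetry():
+    # With the default scale=True every per-feature term lies in [0, 1], so
+    # the averaged Gower distance must as well; the matrix is symmetric with
+    # a zero diagonal. A minimal sanity check of those invariants.
+    rng = check_random_state(0)
+    X = rng.rand(6, 4) * 100  # unscaled numerical data
+    D = gower_distances(X, categorical_features=None, scale=True)
+    assert np.all(D >= 0)
+    assert np.all(D <= 1 + 1e-12)
+    assert_array_almost_equal(D, D.T)
+    assert_array_almost_equal(np.diag(D), np.zeros(len(X)))
+
+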
+def test_gower_distance_input_validation():
+    with pytest.raises(TypeError, match="support sparse matrices"):
+        gower_distances(csr_matrix((2, 2)))
+    with pytest.raises(ValueError, match="X can not be None or empty"):
+        gower_distances(None)
+
+    pd = pytest.importorskip("pandas")
+
+    X = pd.DataFrame([['M', False, 222.22, 1],
+                      ['F', True, 333.22, 2],
+                      ['M', True, 1934.0, 4],
+                      [np.nan, np.nan, np.nan, np.nan]])
+
+    # categorical features must be specified if any exist in the data
+    with pytest.raises(ValueError, match="could not convert string to float"):
+        gower_distances(X, scale=True)
+
+
+def test_gower_distances_pairwise_equivalence():
+    # the call to pairwise_distances should yield the same results
+    # even with parallel processing
+    rng = check_random_state(42)
+    X = rng.randint(size=(5, 10), low=0, high=10)
+    Y = rng.randint(size=(5, 10), low=-10, high=10)
+    gower = gower_distances(
+        X, Y, categorical_features=slice(0, 4), scale=False)
+    pw_gower = pairwise_distances(X, Y, metric='gower', n_jobs=2,
+                                  categorical_features=slice(0, 4),
+                                  scale=False)
+    assert_array_almost_equal(pw_gower, gower)
+
+
+def test_gower_distances_cdist_equivalence():
+    # test that gower is consistent with L1 and hamming on numerical-only and
+    # categorical-only features respectively
+    rng = check_random_state(42)
+    X = rng.randint(size=(5, 10), low=0, high=10)
+
+    l1 = cdist(X, X, metric='minkowski', p=1) / 10
+    gower_numerical = gower_distances(
+        X, categorical_features=None, scale=False)
+    assert_array_almost_equal(l1, gower_numerical)
+
+    hamming = cdist(X, X, metric='hamming')
+    gower_categorical = gower_distances(
+        X, categorical_features=slice(0, 10))
+    assert_array_almost_equal(hamming, gower_categorical)
+
+    # mixed categorical and numerical data should be equivalent to the
+    # combination of L1 and hamming
+    l1 = cdist(X[:, 6:], X[:, 6:], metric='minkowski', p=1) / 4
+    hamming = cdist(X[:, :6], X[:, :6], metric='hamming')
+    gower = gower_distances(X, categorical_features=slice(0, 6),
+                            scale=False)
+    assert_array_almost_equal(
+        gower, (hamming * 6 + l1 * 4) / 10)
+
+
+def test_gower_distances_scaling():
+    rng = check_random_state(42)
+    X = rng.randint(size=(5, 10), low=0, high=10)
+    # assuming the first 6 cols to be categorical, the rest numerical
+
+    # test scaling of the numerical values
+    X_scaled = minmax_scale(X[:, 6:])
+    l1 = cdist(X_scaled, X_scaled, metric='minkowski', p=1) / 4
+    hamming = cdist(X[:, :6], X[:, :6], metric='hamming')
+    gower = gower_distances(X, categorical_features=slice(0, 6),
+                            scale=True)
+    assert_array_almost_equal(
+        gower, (hamming * 6 + l1 * 4) / 10)
+
+    # test providing the scaling factors to gower_distances
+    scaler = MinMaxScaler().fit(X[:, 6:])
+    gower = gower_distances(X, categorical_features=slice(0, 6),
+                            scale=True, min_values=scaler.min_,
+                            scale_factor=scaler.scale_)
+    assert_array_almost_equal(
+        gower, (hamming * 6 + l1 * 4) / 10)
+
+    # make sure the given min_ and scale_ are used, by providing wrong values
+    # and getting wrong distances as output
+    gower = gower_distances(X, categorical_features=slice(0, 6),
+                            scale=True, min_values=scaler.min_,
+                            scale_factor=scaler.scale_ + 1)
+    with pytest.raises(AssertionError):
+        assert_array_almost_equal(
+            gower, (hamming * 6 + l1 * 4) / 10)
+
+    # passing the scaling factors when Y is provided
+    Y = (X + 1)[::2, :]
+    trs = MinMaxScaler().fit(np.vstack((X[:, 6:], Y[:, 6:])))
+    l1 = cdist(trs.transform(X[:, 6:]), trs.transform(Y[:, 6:]),
+               metric='minkowski', p=1) / 4
+    hamming = cdist(X[:, :6], Y[:, :6], metric='hamming')
+    gower = gower_distances(X, Y, categorical_features=slice(0, 6),
+                            scale=True, min_values=trs.min_,
+                            scale_factor=trs.scale_)
+    assert_array_almost_equal(
+        gower, (hamming * 6 + l1 * 4) / 10)
+    assert gower.shape == (len(X), len(Y))
+
+    # passing Y without scaling factors should fail
+    with pytest.raises(ValueError, match="`scale_factor` and `min_values` "
+                       "must be provided when `Y` is provided and "
+                       "`scale=True`"):
+        gower_distances(X, Y, categorical_features=slice(0, 6),
+                        scale=True)
+
+
+@pytest.mark.parametrize('cat',
+                         [slice(0, 2),
+                          ['col_0', 'col_1'],
+                          [0, 1],
+                          make_column_selector(dtype_include=object),
+                          [True, True, False, False]])
+def test_gower_distances_dataframe(cat):
+    # test different ways of specifying categorical features on a DataFrame
+    pd = pytest.importorskip('pandas')
+
+    X = pd.DataFrame([['M', False, 222.22, 1],
+                      ['F', True, 333.22, 2],
+                      ['M', True, 1934.0, 4]],
+                     columns=[f'col_{d}' for d in range(4)])
+    X_cat = OrdinalEncoder().fit_transform(X.iloc[:, :2])
+    X_num = np.array(X.iloc[:, 2:])
+
+    l1 = cdist(X_num, X_num, metric='minkowski', p=1) / 2
+    hamming = cdist(X_cat, X_cat, metric='hamming')
+    gower = gower_distances(X, categorical_features=cat,
+                            scale=False)
+    assert_array_almost_equal(
+        gower, (hamming * 2 + l1 * 2) / 4)
+
+
+def test_gower_distances_nans():
+    pd = pytest.importorskip("pandas")
+
+    rng = check_random_state(42)
+
+    # an all-np.nan sample has an np.nan distance from other data points.
+    X = rng.rand(5, 10)
+    X[4, :] = np.nan
+    dists = gower_distances(X, categorical_features=slice(0, 6))
+    assert np.all(np.isnan(dists[4, :]))
+    assert np.all(np.isnan(dists[:, 4]))
+
+    # Test with a single nan; the value ranges don't matter for this test
+    X = [[9222.22, -11, 'M', 1],
+         [41934.0, -44, 'F', 1],
+         [1, 1, np.nan, 0]]
+
+    Y = [[-222.22, 1, 'F', 0],
+         [1934.0, 4, 'M', 0],
+         [3000, 3000, 'F', 0]]
+
+    # Simplified calculation of the Gower distance for the expected values
+    D_expected = np.zeros((3, 3))
+    # The number of non-missing cols for each row of X (Y has no missing
+    # values)
+    non_missing_cols = [4, 4, 3]
+    for i in range(0, 3):
+        for j in range(0, 3):
+            # The calculation below compares the observations attribute by
+            # attribute; the categorical column is skipped when either value
+            # is nan (nan != nan serves as the missingness check).
+            D_expected[i][j] = ((abs(X[i][0] - Y[j][0]) +
+                                 abs(X[i][1] - Y[j][1]) +
+                                 ([1, 0][X[i][2] == Y[j][2]]
+                                  if (X[i][2] == X[i][2] and
+                                      Y[j][2] == Y[j][2]) else 0) +
+                                 abs(X[i][3] - Y[j][3])) /
+                                non_missing_cols[i])
+
+    # pairwise_distances would convert plain lists to strings and np.nan would
+    # therefore become 'nan'. Passing DataFrames avoids that.
+    D = pairwise_distances(pd.DataFrame(X), pd.DataFrame(Y), metric='gower',
+                           n_jobs=2,
+                           categorical_features=[2],
+                           scale=False)
+    assert_array_almost_equal(D_expected, D)
+
+
 def test_haversine_distances():
     # Check haversine distance with distances computation
     def slow_haversine_distances(x, y):
diff --git a/sklearn/neighbors/_base.py b/sklearn/neighbors/_base.py
index a1eebdcf78648..53d370cc6f3a2 100644
--- a/sklearn/neighbors/_base.py
+++ b/sklearn/neighbors/_base.py
@@ -47,7 +47,7 @@
 VALID_METRICS_SPARSE = dict(ball_tree=[],
                             kd_tree=[],
                             brute=(PAIRWISE_DISTANCE_FUNCTIONS.keys() -
-                                   {'haversine', 'nan_euclidean'}))
+                                   {'haversine', 'nan_euclidean', 'gower'}))
 
 
 def _check_weights(weights):
diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py
index 9a89984880396..9befbb30aa8c6 100644
--- a/sklearn/utils/__init__.py
+++ b/sklearn/utils/__init__.py
@@ -24,7 +24,6 @@
 from . import _joblib
 from ..exceptions import DataConversionWarning
 from .deprecation import deprecated
-from .fixes import np_version
 from ._estimator_html_repr import estimator_html_repr
 from .validation import (as_float_array,
                          assert_all_finite,
@@ -169,19 +168,22 @@ def axis0_safe_slice(X, mask, len_mask):
     return np.zeros(shape=(0, X.shape[1]))
 
 
-def _array_indexing(array, key, key_dtype, axis):
+def _array_indexing(array, key, key_dtype, axis, complement):
     """Index an array or scipy.sparse consistently across NumPy version."""
-    if np_version < (1, 12) or issparse(array):
-        # FIXME: Remove the check for NumPy when using >= 1.12
+    if issparse(array):
         # check if we have an boolean array-likes to make the proper indexing
         if key_dtype == 'bool':
             key = np.asarray(key)
     if isinstance(key, tuple):
         key = list(key)
+    if complement:
+        mask = np.ones(array.shape[0] if axis == 0 else array.shape[1], bool)
+        mask[key] = False
+        key = mask
     return array[key] if axis == 0 else array[:, key]
 
 
-def _pandas_indexing(X, key, key_dtype, axis):
+def _pandas_indexing(X, key, key_dtype, axis, complement):
     """Index a pandas dataframe or a series."""
     if hasattr(key, 'shape'):
         # Work-around for indexing with read-only key in pandas
@@ -192,11 +194,28 @@ def _pandas_indexing(X, key, key_dtype, axis):
         key = list(key)
     # check whether we should index with loc or iloc
     indexer = X.iloc if key_dtype == 'int' else X.loc
+    if complement:
+        if key_dtype == 'str':
+            # string keys are only valid for columns: convert the names to
+            # positional indices to build the boolean mask
+            key = _get_column_indices(X, key)
+        if isinstance(key, tuple):
+            key = list(key)
+        mask = np.ones(X.shape[0] if axis == 0 else X.shape[1], bool)
+        mask[key] = False
+        key = mask
     return indexer[:, key] if axis else indexer[key]
 
 
-def _list_indexing(X, key, key_dtype):
+def _list_indexing(X, key, key_dtype, complement):
     """Index a Python list."""
+    if complement:
+        if isinstance(key, tuple):
+            key = list(key)
+        mask = np.ones(len(X), bool)
+        mask[key] = False
+        key = mask
+        key_dtype = 'bool'
+
     if np.isscalar(key) or isinstance(key, slice):
         # key is a slice or a scalar
         return X[key]
@@ -270,7 +289,7 @@ def _determine_key_type(key, accept_slice=True):
         raise ValueError(err_msg)
 
 
-def _safe_indexing(X, indices, *, axis=0):
+def _safe_indexing(X, indices, *, axis=0, complement=False):
     """Return rows, items or columns of X using indices.
 
     .. warning::
@@ -300,6 +319,9 @@ def _safe_indexing(X, indices, *, axis=0):
     axis : int, default=0
         The axis along which `X` will be subsampled. `axis=0` will select
         rows while `axis=1` will select columns.
+    complement : bool, default=False
+        Whether to select the given indices, or to deselect them and return
+        the remaining rows or columns.
 
Returns ------- @@ -341,11 +363,14 @@ def _safe_indexing(X, indices, *, axis=0): ) if hasattr(X, "iloc"): - return _pandas_indexing(X, indices, indices_dtype, axis=axis) + return _pandas_indexing(X, indices, indices_dtype, axis=axis, + complement=complement) elif hasattr(X, "shape"): - return _array_indexing(X, indices, indices_dtype, axis=axis) + return _array_indexing(X, indices, indices_dtype, axis=axis, + complement=complement) else: - return _list_indexing(X, indices, indices_dtype) + return _list_indexing(X, indices, indices_dtype, + complement=complement) def _get_column_indices(X, key): diff --git a/sklearn/utils/tests/test_utils.py b/sklearn/utils/tests/test_utils.py index 58f80209c7732..50f0afd6a9f9c 100644 --- a/sklearn/utils/tests/test_utils.py +++ b/sklearn/utils/tests/test_utils.py @@ -275,6 +275,26 @@ def test_safe_indexing_2d_container_axis_0(array_type, indices_type): ) +@pytest.mark.parametrize( + "array_type", ["list", "array", "sparse", "dataframe"] +) +@pytest.mark.parametrize( + "indices_type", ["list", "tuple", "array", "series", "slice"] +) +def test_safe_indexing_2d_container_axis_0_complement( + array_type, indices_type): + indices = [1, 2] + if indices_type == 'slice' and isinstance(indices[1], int): + indices[1] += 1 + array = _convert_container([[1, 2, 3], [4, 5, 6], [7, 8, 9], + [10, 11, 12]], array_type) + indices = _convert_container(indices, indices_type) + subset = _safe_indexing(array, indices, axis=0, complement=True) + assert_allclose_dense_sparse( + subset, _convert_container([[1, 2, 3], [10, 11, 12]], array_type) + ) + + @pytest.mark.parametrize("array_type", ["list", "array", "series"]) @pytest.mark.parametrize( "indices_type", ["list", "tuple", "array", "series", "slice"] @@ -321,6 +341,38 @@ def test_safe_indexing_2d_container_axis_1(array_type, indices_type, indices): ) +@pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"]) +@pytest.mark.parametrize( + "indices_type", ["list", "tuple", "array", "series", "slice"] +) +@pytest.mark.parametrize("indices", [[1, 2], ["col_1", "col_2"]]) +def test_safe_indexing_2d_container_axis_1_complement( + array_type, indices_type, indices): + # validation of complement indexing + # we make a copy because indices is mutable and shared between tests + indices_converted = copy(indices) + if indices_type == 'slice' and isinstance(indices[1], int): + indices_converted[1] += 1 + + columns_name = ['col_0', 'col_1', 'col_2'] + array = _convert_container( + [[1, 2, 3], [4, 5, 6], [7, 8, 9]], array_type, columns_name + ) + indices_converted = _convert_container(indices_converted, indices_type) + + if isinstance(indices[0], str) and array_type != 'dataframe': + err_msg = ("Specifying the columns using strings is only supported " + "for pandas DataFrames") + with pytest.raises(ValueError, match=err_msg): + _safe_indexing(array, indices_converted, axis=1, complement=True) + else: + subset = _safe_indexing(array, indices_converted, axis=1, + complement=True) + assert_allclose_dense_sparse( + subset, _convert_container([[1], [4], [7]], array_type) + ) + + @pytest.mark.parametrize("array_read_only", [True, False]) @pytest.mark.parametrize("indices_read_only", [True, False]) @pytest.mark.parametrize("array_type", ["array", "sparse", "dataframe"])