diff --git a/doc/modules/manifold.rst b/doc/modules/manifold.rst
index d9c65bcaf7bdb..19694ff0cb422 100644
--- a/doc/modules/manifold.rst
+++ b/doc/modules/manifold.rst
@@ -418,20 +418,19 @@ Multi-dimensional Scaling (MDS)
 representation of the data in which the distances respect well the
 distances in the original high-dimensional space.
 
-In general, :class:`MDS` is a technique used for analyzing similarity or
-dissimilarity data. It attempts to model similarity or dissimilarity data as
-distances in a geometric space. The data can be ratings of similarity between
+In general, :class:`MDS` is a technique used for analyzing
+dissimilarity data. It attempts to model dissimilarities as
+distances in a Euclidean space. The data can be ratings of dissimilarity between
 objects, interaction frequencies of molecules, or trade indices between
 countries.
 
 There exist two types of MDS algorithm: metric and non-metric. In
-scikit-learn, the class :class:`MDS` implements both. In Metric MDS, the input
-similarity matrix arises from a metric (and thus respects the triangular
-inequality), the distances between output two points are then set to be as
-close as possible to the similarity or dissimilarity data. In the non-metric
-version, the algorithms will try to preserve the order of the distances, and
+scikit-learn, the class :class:`MDS` implements both. In metric MDS,
+the distances in the embedding space are set as
+close as possible to the dissimilarity data. In the non-metric
+version, the algorithm will try to preserve the order of the distances, and
 hence seek for a monotonic relationship between the distances in the embedded
-space and the similarities/dissimilarities.
+space and the input dissimilarities.
 
 .. figure:: ../auto_examples/manifold/images/sphx_glr_plot_lle_digits_010.png
    :target: ../auto_examples/manifold/plot_lle_digits.html
@@ -439,46 +438,45 @@ space and the similarities/dissimilarities.
    :align: center
    :scale: 50
 
-Let :math:`S` be the similarity matrix, and :math:`X` the coordinates of the
-:math:`n` input points. Disparities :math:`\hat{d}_{ij}` are transformation of
-the similarities chosen in some optimal ways. The objective, called the
-stress, is then defined by :math:`\sum_{i < j} d_{ij}(X) - \hat{d}_{ij}(X)`
+Let :math:`\delta_{ij}` be the dissimilarity matrix between the
+:math:`n` input points (possibly arising as some pairwise distances
+:math:`d_{ij}(X)` between the coordinates :math:`X` of the input points).
+Disparities :math:`\hat{d}_{ij} = f(\delta_{ij})` are some transformation of
+the dissimilarities. The MDS objective, called the raw stress, is then
+defined by :math:`\sum_{i < j} (\hat{d}_{ij} - d_{ij}(Z))^2`,
+where :math:`d_{ij}(Z)` are the pairwise distances between the
+coordinates :math:`Z` of the embedded points.
 
 .. dropdown:: Metric MDS
 
-  The simplest metric :class:`MDS` model, called *absolute MDS*, disparities are defined by
-  :math:`\hat{d}_{ij} = S_{ij}`. With absolute MDS, the value :math:`S_{ij}`
-  should then correspond exactly to the distance between point :math:`i` and
-  :math:`j` in the embedding point.
-
-  Most commonly, disparities are set to :math:`\hat{d}_{ij} = b S_{ij}`.
+  In the metric :class:`MDS` model (sometimes also called *absolute MDS*),
+  disparities are simply equal to the input dissimilarities:
+  :math:`\hat{d}_{ij} = \delta_{ij}`.
 
 .. dropdown:: Nonmetric MDS
 
   Non metric :class:`MDS` focuses on the ordination of the data. If
-  :math:`S_{ij} > S_{jk}`, then the embedding should enforce :math:`d_{ij} <
-  d_{jk}`. For this reason, we discuss it in terms of dissimilarities
-  (:math:`\delta_{ij}`) instead of similarities (:math:`S_{ij}`). Note that
-  dissimilarities can easily be obtained from similarities through a simple
-  transform, e.g. :math:`\delta_{ij}=c_1-c_2 S_{ij}` for some real constants
-  :math:`c_1, c_2`. A simple algorithm to enforce proper ordination is to use a
-  monotonic regression of :math:`d_{ij}` on :math:`\delta_{ij}`, yielding
-  disparities :math:`\hat{d}_{ij}` in the same order as :math:`\delta_{ij}`.
-
-  A trivial solution to this problem is to set all the points on the origin. In
-  order to avoid that, the disparities :math:`\hat{d}_{ij}` are normalized. Note
-  that since we only care about relative ordering, our objective should be
+  :math:`\delta_{ij} > \delta_{kl}`, then the embedding
+  seeks to enforce :math:`d_{ij}(Z) > d_{kl}(Z)`. A simple algorithm
+  to enforce proper ordination is to use an
+  isotonic regression of :math:`d_{ij}(Z)` on :math:`\delta_{ij}`, yielding
+  disparities :math:`\hat{d}_{ij}` that are a monotonic transformation
+  of the dissimilarities :math:`\delta_{ij}` and hence have the same ordering.
+  This is done repeatedly after every step of the optimization algorithm.
+  In order to avoid the trivial solution where all embedding points are
+  overlapping, the disparities :math:`\hat{d}_{ij}` are normalized.
+
+  Note that since we only care about relative ordering, our objective should be
   invariant to simple translation and scaling, however the stress used in metric
-  MDS is sensitive to scaling. To address this, non-metric MDS may use a
-  normalized stress, known as Stress-1 defined as
+  MDS is sensitive to scaling. To address this, non-metric MDS returns
+  normalized stress, also known as Stress-1, defined as
 
   .. math::
-    \sqrt{\frac{\sum_{i < j} (d_{ij} - \hat{d}_{ij})^2}{\sum_{i < j} d_{ij}^2}}.
+    \sqrt{\frac{\sum_{i < j} (\hat{d}_{ij} - d_{ij}(Z))^2}{\sum_{i < j}
+    d_{ij}(Z)^2}}.
 
-  The use of normalized Stress-1 can be enabled by setting `normalized_stress=True`,
-  however it is only compatible with the non-metric MDS problem and will be ignored
-  in the metric case.
+  Normalized Stress-1 is returned if `normalized_stress=True`.
 
 .. figure:: ../auto_examples/manifold/images/sphx_glr_plot_mds_001.png
    :target: ../auto_examples/manifold/plot_mds.html
@@ -487,6 +485,10 @@ stress, is then defined by :math:`\sum_{i < j} d_{ij}(X) - \hat{d}_{ij}(X)`
 
 .. rubric:: References
 
+* `"More on Multidimensional Scaling and Unfolding in R: smacof Version 2"
+  <https://www.jstatsoft.org/article/view/v102i10>`_
+  Mair P., Groenen P., de Leeuw J. Journal of Statistical Software (2022)
+
 * `"Modern Multidimensional Scaling - Theory and Applications"
   `_
   Borg, I.; Groenen P. Springer Series in Statistics (1997)
diff --git a/doc/whats_new/upcoming_changes/sklearn.manifold/30514.fix.rst b/doc/whats_new/upcoming_changes/sklearn.manifold/30514.fix.rst
new file mode 100644
index 0000000000000..7f4e4104446dc
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.manifold/30514.fix.rst
@@ -0,0 +1,4 @@
+- :class:`manifold.MDS` now correctly handles non-metric MDS. Furthermore,
+  the returned stress value now corresponds to the returned embedding and
+  normalized stress is now allowed for metric MDS.
+  By :user:`Dmitry Kobak `
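As a quick illustration of the two MDS variants described in the documentation above, here is a minimal usage sketch (not part of the patch; the toy data is arbitrary, and with the new `normalized_stress="auto"` default, `stress_` is raw stress for metric MDS and Stress-1 for non-metric MDS)::

    import numpy as np
    from sklearn.manifold import MDS
    from sklearn.metrics import euclidean_distances

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(10, 3))
    D = euclidean_distances(X)  # input dissimilarities delta_ij

    # Metric MDS: disparities are the input dissimilarities themselves.
    metric_mds = MDS(n_components=2, dissimilarity="precomputed", random_state=0)
    Z_metric = metric_mds.fit_transform(D)

    # Non-metric MDS: only the ordering of the dissimilarities matters.
    nonmetric_mds = MDS(
        n_components=2, metric=False, dissimilarity="precomputed", random_state=0
    )
    Z_nonmetric = nonmetric_mds.fit_transform(D)

    print(metric_mds.stress_)     # raw stress (default for metric MDS)
    print(nonmetric_mds.stress_)  # Stress-1 (default for non-metric MDS)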
diff --git a/examples/manifold/plot_mds.py b/examples/manifold/plot_mds.py
index c572e792ac71b..afea676b245a8 100644
--- a/examples/manifold/plot_mds.py
+++ b/examples/manifold/plot_mds.py
@@ -21,31 +21,34 @@
 from sklearn.decomposition import PCA
 from sklearn.metrics import euclidean_distances
 
+# Generate the data
 EPSILON = np.finfo(np.float32).eps
 n_samples = 20
-seed = np.random.RandomState(seed=3)
-X_true = seed.randint(0, 20, 2 * n_samples).astype(float)
+rng = np.random.RandomState(seed=3)
+X_true = rng.randint(0, 20, 2 * n_samples).astype(float)
 X_true = X_true.reshape((n_samples, 2))
+
 # Center the data
 X_true -= X_true.mean()
 
-similarities = euclidean_distances(X_true)
+# Compute pairwise Euclidean distances
+distances = euclidean_distances(X_true)
 
-# Add noise to the similarities
-noise = np.random.rand(n_samples, n_samples)
+# Add noise to the distances
+noise = rng.rand(n_samples, n_samples)
 noise = noise + noise.T
-noise[np.arange(noise.shape[0]), np.arange(noise.shape[0])] = 0
-similarities += noise
+np.fill_diagonal(noise, 0)
+distances += noise
 
 mds = manifold.MDS(
     n_components=2,
     max_iter=3000,
     eps=1e-9,
-    random_state=seed,
+    random_state=42,
     dissimilarity="precomputed",
     n_jobs=1,
 )
-pos = mds.fit(similarities).embedding_
+X_mds = mds.fit(distances).embedding_
 
 nmds = manifold.MDS(
     n_components=2,
@@ -53,47 +56,52 @@
     max_iter=3000,
     eps=1e-12,
     dissimilarity="precomputed",
-    random_state=seed,
+    random_state=42,
     n_jobs=1,
     n_init=1,
 )
-npos = nmds.fit_transform(similarities, init=pos)
+X_nmds = nmds.fit_transform(distances)
 
 # Rescale the data
-pos *= np.sqrt((X_true**2).sum()) / np.sqrt((pos**2).sum())
-npos *= np.sqrt((X_true**2).sum()) / np.sqrt((npos**2).sum())
+X_mds *= np.sqrt((X_true**2).sum()) / np.sqrt((X_mds**2).sum())
+X_nmds *= np.sqrt((X_true**2).sum()) / np.sqrt((X_nmds**2).sum())
 
 # Rotate the data
-clf = PCA(n_components=2)
-X_true = clf.fit_transform(X_true)
-
-pos = clf.fit_transform(pos)
-
-npos = clf.fit_transform(npos)
+pca = PCA(n_components=2)
+X_true = pca.fit_transform(X_true)
+X_mds = pca.fit_transform(X_mds)
+X_nmds = pca.fit_transform(X_nmds)
+
+# Align the sign of PCs
+for i in [0, 1]:
+    if np.corrcoef(X_mds[:, i], X_true[:, i])[0, 1] < 0:
+        X_mds[:, i] *= -1
+    if np.corrcoef(X_nmds[:, i], X_true[:, i])[0, 1] < 0:
+        X_nmds[:, i] *= -1
 
 fig = plt.figure(1)
 ax = plt.axes([0.0, 0.0, 1.0, 1.0])
 
 s = 100
 plt.scatter(X_true[:, 0], X_true[:, 1], color="navy", s=s, lw=0, label="True Position")
-plt.scatter(pos[:, 0], pos[:, 1], color="turquoise", s=s, lw=0, label="MDS")
-plt.scatter(npos[:, 0], npos[:, 1], color="darkorange", s=s, lw=0, label="NMDS")
+plt.scatter(X_mds[:, 0], X_mds[:, 1], color="turquoise", s=s, lw=0, label="MDS")
+plt.scatter(X_nmds[:, 0], X_nmds[:, 1], color="darkorange", s=s, lw=0, label="NMDS")
 plt.legend(scatterpoints=1, loc="best", shadow=False)
 
-similarities = similarities.max() / (similarities + EPSILON) * 100
-np.fill_diagonal(similarities, 0)
-
 # Plot the edges
-start_idx, end_idx = np.where(pos)
+start_idx, end_idx = np.where(X_mds)
 # a sequence of (*line0*, *line1*, *line2*), where::
 #     linen = (x0, y0), (x1, y1), ... (xm, ym)
 segments = [
-    [X_true[i, :], X_true[j, :]] for i in range(len(pos)) for j in range(len(pos))
+    [X_true[i, :], X_true[j, :]] for i in range(len(X_true)) for j in range(len(X_true))
 ]
-values = np.abs(similarities)
+edges = distances.max() / (distances + EPSILON) * 100
+np.fill_diagonal(edges, 0)
+edges = np.abs(edges)
 lc = LineCollection(
-    segments, zorder=0, cmap=plt.cm.Blues, norm=plt.Normalize(0, values.max())
+    segments, zorder=0, cmap=plt.cm.Blues, norm=plt.Normalize(0, edges.max())
 )
-lc.set_array(similarities.flatten())
+lc.set_array(edges.flatten())
 lc.set_linewidths(np.full(len(segments), 0.5))
 ax.add_collection(lc)
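The `_mds.py` changes below rework the non-metric SMACOF loop. As a reading aid, here is a simplified, self-contained sketch of one such iteration (isotonic disparities followed by a Guttman transform); the helper name `nonmetric_smacof_step` is hypothetical, and the sketch omits the missing-value handling, first-iteration special case, and convergence logic of the real solver::

    import numpy as np
    from sklearn.isotonic import IsotonicRegression
    from sklearn.metrics import euclidean_distances

    def nonmetric_smacof_step(delta, Z):
        """One simplified non-metric SMACOF iteration for an (n, n)
        dissimilarity matrix `delta` and a current (n, d) embedding `Z`."""
        n = len(Z)
        i, j = np.triu_indices(n, k=1)
        dist = euclidean_distances(Z)

        # Disparities: the best fit to the current embedding distances
        # that is monotonic in the dissimilarities.
        ir = IsotonicRegression(out_of_bounds="clip")
        disparities = np.zeros_like(dist)
        disparities[i, j] = ir.fit_transform(delta[i, j], dist[i, j])

        # Normalize to rule out the trivial all-points-overlap solution.
        disparities *= np.sqrt((n * (n - 1) / 2) / (disparities**2).sum())
        disparities = disparities + disparities.T

        # Guttman transform: the majorization update of the embedding.
        dist[dist == 0] = 1e-5
        ratio = disparities / dist
        B = -ratio
        B[np.arange(n), np.arange(n)] += ratio.sum(axis=1)
        return B @ Z / n

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(8, 3))
    delta = euclidean_distances(X)
    Z = rng.uniform(size=(8, 2))
    for _ in range(100):
        Z = nonmetric_smacof_step(delta, Z)  # repeated updates shrink the stress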
diff --git a/sklearn/manifold/_mds.py b/sklearn/manifold/_mds.py
index dc9f88b502da5..07d492bdcd34d 100644
--- a/sklearn/manifold/_mds.py
+++ b/sklearn/manifold/_mds.py
@@ -70,12 +70,14 @@ def _smacof_single(
         See :term:`Glossary <random_state>`.
 
     normalized_stress : bool, default=False
-        Whether use and return normed stress value (Stress-1) instead of raw
-        stress calculated by default. Only supported in non-metric MDS. The
-        caller must ensure that if `normalized_stress=True` then `metric=False`
+        Whether to use and return normalized stress value (Stress-1) instead
+        of raw stress.
 
         .. versionadded:: 1.2
 
+        .. versionchanged:: 1.7
+            Normalized stress is now supported for metric MDS as well.
+
     Returns
     -------
     X : ndarray of shape (n_samples, n_components)
@@ -84,7 +86,7 @@ def _smacof_single(
     stress : float
         The final value of the stress (sum of squared distance of the
         disparities and the distances for all constrained points).
-        If `normalized_stress=True`, and `metric=False` returns Stress-1.
+        If `normalized_stress=True`, returns Stress-1.
         A value of 0 indicates "perfect" fit, 0.025 excellent, 0.05 good,
         0.1 fair, and 0.2 poor [1]_.
@@ -107,8 +109,8 @@ def _smacof_single(
     n_samples = dissimilarities.shape[0]
     random_state = check_random_state(random_state)
 
-    sim_flat = ((1 - np.tri(n_samples)) * dissimilarities).ravel()
-    sim_flat_w = sim_flat[sim_flat != 0]
+    dissimilarities_flat = ((1 - np.tri(n_samples)) * dissimilarities).ravel()
+    dissimilarities_flat_w = dissimilarities_flat[dissimilarities_flat != 0]
     if init is None:
         # Randomly choose initial configuration
         X = random_state.uniform(size=n_samples * n_components)
@@ -121,49 +123,63 @@ def _smacof_single(
                 "init matrix should be of shape (%d, %d)" % (n_samples, n_components)
             )
         X = init
+    distances = euclidean_distances(X)
+
+    # Out of bounds condition cannot happen because we are transforming
+    # the training set here, but does sometimes get triggered in
+    # practice due to machine precision issues. Hence "clip".
+    ir = IsotonicRegression(out_of_bounds="clip")
 
     old_stress = None
-    ir = IsotonicRegression()
     for it in range(max_iter):
         # Compute distance and monotonic regression
-        dis = euclidean_distances(X)
-
         if metric:
             disparities = dissimilarities
         else:
-            dis_flat = dis.ravel()
+            distances_flat = distances.ravel()
             # dissimilarities with 0 are considered as missing values
-            dis_flat_w = dis_flat[sim_flat != 0]
-
-            # Compute the disparities using a monotonic regression
-            disparities_flat = ir.fit_transform(sim_flat_w, dis_flat_w)
-            disparities = dis_flat.copy()
-            disparities[sim_flat != 0] = disparities_flat
+            distances_flat_w = distances_flat[dissimilarities_flat != 0]
+
+            # Compute the disparities using isotonic regression.
+            # For the first SMACOF iteration, use scaled original dissimilarities.
+            # (This choice follows the R implementation described in this paper:
+            # https://www.jstatsoft.org/article/view/v102i10)
+            if it < 1:
+                disparities_flat = dissimilarities_flat_w
+            else:
+                disparities_flat = ir.fit_transform(
+                    dissimilarities_flat_w, distances_flat_w
+                )
+            disparities = np.zeros_like(distances_flat)
+            disparities[dissimilarities_flat != 0] = disparities_flat
             disparities = disparities.reshape((n_samples, n_samples))
             disparities *= np.sqrt(
                 (n_samples * (n_samples - 1) / 2) / (disparities**2).sum()
             )
+            disparities = disparities + disparities.T
 
-        # Compute stress
-        stress = ((dis.ravel() - disparities.ravel()) ** 2).sum() / 2
-        if normalized_stress:
-            stress = np.sqrt(stress / ((disparities.ravel() ** 2).sum() / 2))
         # Update X using the Guttman transform
-        dis[dis == 0] = 1e-5
-        ratio = disparities / dis
+        distances[distances == 0] = 1e-5
+        ratio = disparities / distances
         B = -ratio
         B[np.arange(len(B)), np.arange(len(B))] += ratio.sum(axis=1)
         X = 1.0 / n_samples * np.dot(B, X)
 
-        dis = np.sqrt((X**2).sum(axis=1)).sum()
-
-        if verbose >= 2:
-            print("it: %d, stress %s" % (it, stress))
+        # Compute stress
+        distances = euclidean_distances(X)
+        stress = ((distances.ravel() - disparities.ravel()) ** 2).sum() / 2
+        if normalized_stress:
+            stress = np.sqrt(stress / ((disparities.ravel() ** 2).sum() / 2))
+
+        normalization = np.sqrt((X**2).sum(axis=1)).sum()
+        if verbose >= 2:  # pragma: no cover
+            print(f"Iteration {it}, stress {stress:.4f}")
         if old_stress is not None:
-            if (old_stress - stress / dis) < eps:
-                if verbose:
-                    print("breaking at iteration %d with stress %s" % (it, stress))
+            if (old_stress - stress / normalization) < eps:
+                if verbose:  # pragma: no cover
+                    print("Convergence criterion reached.")
                 break
-        old_stress = stress / dis
+        old_stress = stress / normalization
 
     return X, stress, it + 1
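To make the isotonic step above concrete, a tiny standalone demo of what `IsotonicRegression` produces (values picked arbitrarily); `out_of_bounds="clip"` only matters when inputs fall outside the training range, which, per the comment in the patch, can be triggered here by machine precision issues::

    import numpy as np
    from sklearn.isotonic import IsotonicRegression

    delta = np.array([1.0, 2.0, 3.0, 4.0])  # dissimilarities (predictor)
    dist = np.array([0.5, 1.8, 1.2, 2.5])   # embedding distances (response)

    ir = IsotonicRegression(out_of_bounds="clip")
    print(ir.fit_transform(delta, dist))  # [0.5 1.5 1.5 2.5]
    # The violating pair (1.8, 1.2) is pooled to its mean, so the result is
    # non-decreasing in delta while staying as close as possible to dist.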
@@ -275,14 +291,18 @@ def smacof(
         Whether or not to return the number of iterations.
 
     normalized_stress : bool or "auto" default="auto"
-        Whether use and return normed stress value (Stress-1) instead of raw
-        stress calculated by default. Only supported in non-metric MDS.
+        Whether to return normalized stress value (Stress-1) instead of raw
+        stress. By default, metric MDS returns raw stress while non-metric MDS
+        returns normalized stress.
 
         .. versionadded:: 1.2
 
         .. versionchanged:: 1.4
            The default value changed from `False` to `"auto"` in version 1.4.
 
+        .. versionchanged:: 1.7
+            Normalized stress is now supported for metric MDS as well.
+
     Returns
     -------
     X : ndarray of shape (n_samples, n_components)
@@ -291,7 +311,7 @@ def smacof(
     stress : float
         The final value of the stress (sum of squared distance of the
         disparities and the distances for all constrained points).
-        If `normalized_stress=True`, and `metric=False` returns Stress-1.
+        If `normalized_stress=True`, returns Stress-1.
         A value of 0 indicates "perfect" fit, 0.025 excellent, 0.05 good,
         0.1 fair, and 0.2 poor [1]_.
@@ -318,12 +338,12 @@ def smacof(
     >>> X = np.array([[0, 1, 2], [1, 0, 3],[2, 3, 0]])
     >>> dissimilarities = euclidean_distances(X)
    >>> mds_result, stress = smacof(dissimilarities, n_components=2, random_state=42)
-    >>> mds_result
-    array([[ 0.05... -1.07... ],
-           [ 1.74..., -0.75...],
-           [-1.79...,  1.83...]])
-    >>> stress
-    np.float64(0.0012...)
+    >>> np.round(mds_result, 5)
+    array([[ 0.05352, -1.07253],
+           [ 1.74231, -0.75675],
+           [-1.79583,  1.82928]])
+    >>> np.round(stress, 5).item()
+    0.00128
     """
 
     dissimilarities = check_array(dissimilarities)
@@ -332,11 +352,6 @@ def smacof(
     if normalized_stress == "auto":
         normalized_stress = not metric
 
-    if normalized_stress and metric:
-        raise ValueError(
-            "Normalized stress is not supported for metric MDS. Either set"
-            " `normalized_stress=False` or use `metric=False`."
-        )
     if hasattr(init, "__array__"):
         init = np.asarray(init).copy()
         if not n_init == 1:
@@ -449,14 +464,18 @@ class MDS(BaseEstimator):
         ``fit_transform``.
 
     normalized_stress : bool or "auto" default="auto"
-        Whether use and return normed stress value (Stress-1) instead of raw
-        stress calculated by default. Only supported in non-metric MDS.
+        Whether to use and return normalized stress value (Stress-1) instead
+        of raw stress. By default, metric MDS uses raw stress while
+        non-metric MDS uses normalized stress.
 
         .. versionadded:: 1.2
 
         .. versionchanged:: 1.4
            The default value changed from `False` to `"auto"` in version 1.4.
 
+        .. versionchanged:: 1.7
+            Normalized stress is now supported for metric MDS as well.
+
     Attributes
     ----------
     embedding_ : ndarray of shape (n_samples, n_components)
@@ -465,7 +484,7 @@ class MDS(BaseEstimator):
     stress_ : float
         The final value of the stress (sum of squared distance of the
         disparities and the distances for all constrained points).
-        If `normalized_stress=True`, and `metric=False` returns Stress-1.
+        If `normalized_stress=True`, returns Stress-1.
         A value of 0 indicates "perfect" fit, 0.025 excellent, 0.05 good,
         0.1 fair, and 0.2 poor [1]_.
diff --git a/sklearn/manifold/tests/test_mds.py b/sklearn/manifold/tests/test_mds.py
index 2d286ef0942bf..b34f030b79895 100644
--- a/sklearn/manifold/tests/test_mds.py
+++ b/sklearn/manifold/tests/test_mds.py
@@ -4,6 +4,7 @@
 import pytest
 from numpy.testing import assert_allclose, assert_array_almost_equal
 
+from sklearn.datasets import load_digits
 from sklearn.manifold import _mds as mds
 from sklearn.metrics import euclidean_distances
 
@@ -20,6 +21,74 @@ def test_smacof():
     assert_array_almost_equal(X, X_true, decimal=3)
 
 
+def test_nonmetric_lower_normalized_stress():
+    # Testing that nonmetric MDS results in lower normalized stress
+    # compared to metric MDS (non-regression test for issue 27028)
+    sim = np.array([[0, 5, 3, 4], [5, 0, 2, 2], [3, 2, 0, 1], [4, 2, 1, 0]])
+    Z = np.array([[-0.266, -0.539], [0.451, 0.252], [0.016, -0.238], [-0.200, 0.524]])
+
+    _, stress1 = mds.smacof(
+        sim, init=Z, n_components=2, max_iter=1000, n_init=1, normalized_stress=True
+    )
+
+    _, stress2 = mds.smacof(
+        sim,
+        init=Z,
+        n_components=2,
+        max_iter=1000,
+        n_init=1,
+        normalized_stress=True,
+        metric=False,
+    )
+    assert stress1 > stress2
+
+
+def test_nonmetric_mds_optimization():
+    # Test that stress is decreasing during nonmetric MDS optimization
+    # (non-regression test for issue 27028)
+    X, _ = load_digits(return_X_y=True)
+    rng = np.random.default_rng(seed=42)
+    ind_subset = rng.choice(len(X), size=200, replace=False)
+    X = X[ind_subset]
+
+    mds_est = mds.MDS(
+        n_components=2,
+        n_init=1,
+        eps=1e-15,
+        max_iter=2,
+        metric=False,
+        random_state=42,
+    ).fit(X)
+    stress_after_2_iter = mds_est.stress_
+
+    mds_est = mds.MDS(
+        n_components=2,
+        n_init=1,
+        eps=1e-15,
+        max_iter=3,
+        metric=False,
+        random_state=42,
+    ).fit(X)
+    stress_after_3_iter = mds_est.stress_
+
+    assert stress_after_2_iter > stress_after_3_iter
+
+
+@pytest.mark.parametrize("metric", [True, False])
+def test_mds_recovers_true_data(metric):
+    X = np.array([[1, 1], [1, 4], [1, 5], [3, 3]])
+    mds_est = mds.MDS(
+        n_components=2,
+        n_init=1,
+        eps=1e-15,
+        max_iter=1000,
+        metric=metric,
+        random_state=42,
+    ).fit(X)
+    stress = mds_est.stress_
+    assert_allclose(stress, 0, atol=1e-10)
+
+
 def test_smacof_error():
     # Not symmetric similarity matrix:
     sim = np.array([[0, 5, 9, 4], [5, 0, 2, 2], [3, 2, 0, 1], [4, 2, 1, 0]])
@@ -59,17 +128,6 @@ def test_normed_stress(k):
     assert_allclose(X1, X2, rtol=1e-5)
 
 
-def test_normalize_metric_warning():
-    """
-    Test that a UserWarning is emitted when using normalized stress with
-    metric-MDS.
-    """
-    msg = "Normalized stress is not supported"
-    sim = np.array([[0, 5, 3, 4], [5, 0, 2, 2], [3, 2, 0, 1], [4, 2, 1, 0]])
-    with pytest.raises(ValueError, match=msg):
-        mds.smacof(sim, metric=True, normalized_stress=True)
-
-
 @pytest.mark.parametrize("metric", [True, False])
 def test_normalized_stress_auto(metric, monkeypatch):
     rng = np.random.RandomState(0)
@@ -85,3 +143,39 @@ def test_normalized_stress_auto(metric, monkeypatch):
     mds.smacof(dist, metric=metric, normalized_stress="auto", random_state=rng)
 
     assert mock.call_args[1]["normalized_stress"] != metric
+
+
+def test_isotonic_outofbounds():
+    # This particular configuration can trigger out of bounds error
+    # in the isotonic regression (non-regression test for issue 26999)
+    dis = np.array(
+        [
+            [0.0, 1.732050807568877, 1.7320508075688772],
+            [1.732050807568877, 0.0, 6.661338147750939e-16],
+            [1.7320508075688772, 6.661338147750939e-16, 0.0],
+        ]
+    )
+    init = np.array(
+        [
+            [0.08665881585055124, 0.7939114643387546],
+            [0.9959834154297658, 0.7555546025640025],
+            [0.8766008278401566, 0.4227358815811242],
+        ]
+    )
+    mds.smacof(dis, init=init, metric=False, n_init=1)
+
+
+def test_returned_stress():
+    # Test that the final stress corresponds to the final embedding
+    # (non-regression test for issue 16846)
+    X = np.array([[1, 1], [1, 4], [1, 5], [3, 3]])
+    D = euclidean_distances(X)
+
+    mds_est = mds.MDS(n_components=2, random_state=42).fit(X)
+    Z = mds_est.embedding_
+    stress = mds_est.stress_
+
+    D_mds = euclidean_distances(Z)
+    stress_Z = ((D_mds.ravel() - D.ravel()) ** 2).sum() / 2
+
+    assert_allclose(stress, stress_Z)
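Finally, a sketch tying the new `test_returned_stress` check back to the documentation formulas: raw stress and Stress-1 recomputed by hand from a fitted embedding (toy data reused from the tests; with this patch, the reported `stress_` matches the stress of the returned embedding)::

    import numpy as np
    from sklearn.manifold import MDS
    from sklearn.metrics import euclidean_distances

    X = np.array([[1.0, 1.0], [1.0, 4.0], [1.0, 5.0], [3.0, 3.0]])
    D = euclidean_distances(X)  # input dissimilarities delta_ij

    est = MDS(n_components=2, dissimilarity="precomputed", random_state=42).fit(D)
    D_emb = euclidean_distances(est.embedding_)  # d_ij(Z)

    # Sum over i < j pairs; for metric MDS, disparities equal D.
    i, j = np.triu_indices_from(D, k=1)
    raw_stress = ((D[i, j] - D_emb[i, j]) ** 2).sum()
    stress1 = np.sqrt(raw_stress / (D_emb[i, j] ** 2).sum())

    assert np.isclose(est.stress_, raw_stress)  # holds after this fix
    print(f"raw stress = {raw_stress:.6f}, Stress-1 = {stress1:.6f}")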