Implement classical MDS #31322

Open · wants to merge 5 commits into main
9 changes: 9 additions & 0 deletions doc/modules/manifold.rst
@@ -486,6 +486,15 @@ coordinates :math:`Z` of the embedded points.
:align: center
:scale: 60

Apart from that, there is a version called *classical MDS*, also known as
*principal coordinates analysis (PCoA)* or *Torgerson's scaling*, implemented
in the separate :class:`ClassicalMDS` class. Classical MDS replaces the stress
loss function with a different loss function called *strain*, which allows an
exact solution in terms of an eigendecomposition of the double-centered matrix
of squared dissimilarities. If the dissimilarity matrix consists of pairwise
Euclidean distances between some vectors, then classical MDS is equivalent to
PCA applied to those vectors.

.. rubric:: References

* `"More on Multidimensional Scaling and Unfolding in R: smacof Version 2"
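The double-centering and eigendecomposition route described above, and its
equivalence to PCA for Euclidean distances, can be sketched in a few lines of
plain NumPy. This is an illustrative sketch rather than part of the diff; the
dataset and the tolerance are arbitrary choices.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.metrics import euclidean_distances

X, _ = load_iris(return_X_y=True)
D = euclidean_distances(X)                    # pairwise Euclidean distances

# Strain has a closed-form minimizer: double-center the squared distances to
# recover the Gram matrix B = -0.5 * J D^2 J, then take its top eigenvectors.
B = -0.5 * D**2
B -= B.mean(axis=0)
B -= B.mean(axis=1, keepdims=True)

w, U = np.linalg.eigh(B)                      # eigenvalues in ascending order
w, U = w[::-1][:2], U[:, ::-1][:, :2]         # keep the two largest
Z_cmds = U * np.sqrt(w)

# For Euclidean distances the embedding matches PCA up to per-column signs.
Z_pca = PCA(n_components=2).fit_transform(X)
signs = np.sign((Z_cmds * Z_pca).sum(axis=0))
print(np.allclose(Z_cmds * signs, Z_pca, atol=1e-6))   # True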
@@ -0,0 +1,5 @@
- :class:`manifold.ClassicalMDS` was implemented to perform classical MDS
  (eigendecomposition of the double-centered matrix of squared distances).
  Furthermore, :class:`manifold.MDS` now supports arbitrary pairwise distance
  metrics, not only the Euclidean metric.
  By :user:`Dmitry Kobak <dkobak>`
2 changes: 2 additions & 0 deletions sklearn/manifold/__init__.py
@@ -3,6 +3,7 @@
# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from ._classical_mds import ClassicalMDS
from ._isomap import Isomap
from ._locally_linear import LocallyLinearEmbedding, locally_linear_embedding
from ._mds import MDS, smacof
@@ -12,6 +13,7 @@
__all__ = [
"MDS",
"TSNE",
"ClassicalMDS",
"Isomap",
"LocallyLinearEmbedding",
"SpectralEmbedding",
182 changes: 182 additions & 0 deletions sklearn/manifold/_classical_mds.py
@@ -0,0 +1,182 @@
"""
Classical multi-dimensional scaling (classical MDS).
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from numbers import Integral

import numpy as np
from scipy import linalg

from ..base import BaseEstimator, _fit_context
from ..metrics import pairwise_distances
from ..utils import check_symmetric
from ..utils._param_validation import Interval
from ..utils.extmath import svd_flip
from ..utils.validation import validate_data


class ClassicalMDS(BaseEstimator):
"""Classical multidimensional scaling.

Read more in the :ref:`User Guide <multidimensional_scaling>`.

Parameters
----------
n_components : int, default=2
Number of embedding dimensions.

dissimilarity : str or callable, default='euclidean'
Metric to use for dissimilarity computation. Default is "euclidean".
See the documentation of `scipy.spatial.distance
<https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and
the metrics listed in
:class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric
values.

If metric is "precomputed", X is assumed to be a distance matrix and
must be square during fit.

If metric is a callable function, it takes two arrays representing 1D
vectors as inputs and must return one value indicating the distance
between those vectors. This works for Scipy's metrics, but is less
efficient than passing the metric name as a string.

Attributes
----------
embedding_ : ndarray of shape (n_samples, n_components)
Stores the position of the dataset in the embedding space.

dissimilarity_matrix_ : ndarray of shape (n_samples, n_samples)
Pairwise dissimilarities between the points.

eigenvalues_ : ndarray of shape (n_components,)
Eigenvalues of the double-centered dissimilarity matrix, corresponding
to each of the selected components. They are equal to the squared 2-norms
of the `n_components` variables in the embedding space.

n_features_in_ : int
Number of features seen during :term:`fit`.

feature_names_in_ : ndarray of shape (`n_features_in_`,)
Names of features seen during :term:`fit`. Defined only when `X`
has feature names that are all strings.

See Also
--------
sklearn.decomposition.PCA : Principal component analysis.
MDS : Metric and non-metric MDS.

References
----------
.. [1] "Modern Multidimensional Scaling - Theory and Applications" Borg, I.;
Groenen P. Springer Series in Statistics (1997)

Examples
--------
>>> from sklearn.datasets import load_digits
>>> from sklearn.manifold import ClassicalMDS
>>> X, _ = load_digits(return_X_y=True)
>>> X.shape
(1797, 64)
>>> cmds = ClassicalMDS(n_components=2)
>>> X_emb = cmds.fit_transform(X[:100])
>>> X_emb.shape
(100, 2)
"""

_parameter_constraints: dict = {
"n_components": [Interval(Integral, 1, None, closed="left")],
"dissimilarity": [str, callable],
}

def __init__(
self,
n_components=2,
*,
dissimilarity="euclidean",
):
self.n_components = n_components
self.dissimilarity = dissimilarity

def __sklearn_tags__(self):
tags = super().__sklearn_tags__()
tags.input_tags.pairwise = self.dissimilarity == "precomputed"
return tags

def fit(self, X, y=None):
"""
Compute the embedding positions.

Parameters
----------
X : array-like of shape (n_samples, n_features) or \
(n_samples, n_samples)
Input data. If ``dissimilarity=='precomputed'``, the input should
be the dissimilarity matrix.

y : Ignored
Not used, present for API consistency by convention.

Returns
-------
self : object
Fitted estimator.
"""
self.fit_transform(X)
return self

@_fit_context(prefer_skip_nested_validation=True)
def fit_transform(self, X, y=None):
"""
Compute the embedding positions.

Parameters
----------
X : array-like of shape (n_samples, n_features) or \
(n_samples, n_samples)
Input data. If ``dissimilarity=='precomputed'``, the input should
be the dissimilarity matrix.

y : Ignored
Not used, present for API consistency by convention.

Returns
-------
X_new : ndarray of shape (n_samples, n_components)
The embedding coordinates.
"""

X = validate_data(self, X)

if self.dissimilarity == "precomputed":
self.dissimilarity_matrix_ = X
self.dissimilarity_matrix_ = check_symmetric(
self.dissimilarity_matrix_, raise_exception=True
)
else:
self.dissimilarity_matrix_ = pairwise_distances(
X, metric=self.dissimilarity
)

# Double centering
B = self.dissimilarity_matrix_**2
B = B.astype(np.float64)
B -= np.mean(B, axis=0)
B -= np.mean(B, axis=1, keepdims=True)
B *= -0.5

# Eigendecomposition
w, U = linalg.eigh(B)
w = w[::-1][: self.n_components]
U = U[:, ::-1][:, : self.n_components]

# Set the signs of eigenvectors to enforce deterministic output
U, _ = svd_flip(U, None)

self.embedding_ = np.sqrt(w) * U
self.eigenvalues_ = w

return self.embedding_
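A brief usage sketch of the class defined in this file (assuming this branch
is installed; it mirrors the equivalence test added further below): the two
dissimilarity modes produce the same embedding.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.manifold import ClassicalMDS
from sklearn.metrics import euclidean_distances

X, _ = load_iris(return_X_y=True)

# Fit on raw features: pairwise Euclidean distances are computed internally.
Z_raw = ClassicalMDS(n_components=2).fit_transform(X)

# Fit on a precomputed symmetric distance matrix: same embedding.
D = euclidean_distances(X)
Z_pre = ClassicalMDS(n_components=2, dissimilarity="precomputed").fit_transform(D)

print(np.allclose(Z_raw, Z_pre))   # True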
44 changes: 27 additions & 17 deletions sklearn/manifold/_mds.py
@@ -13,7 +13,7 @@

from ..base import BaseEstimator, _fit_context
from ..isotonic import IsotonicRegression
from ..metrics import euclidean_distances
from ..metrics import euclidean_distances, pairwise_distances
from ..utils import check_array, check_random_state, check_symmetric
from ..utils._param_validation import Interval, StrOptions, validate_params
from ..utils.parallel import Parallel, delayed
@@ -479,15 +479,25 @@ class MDS(BaseEstimator):
Pass an int for reproducible results across multiple function calls.
See :term:`Glossary <random_state>`.

dissimilarity : {'euclidean', 'precomputed'}, default='euclidean'
Dissimilarity measure to use:
dissimilarity : str or callable, default='euclidean'
Metric to use for dissimilarity computation. Default is "euclidean".
See the documentation of `scipy.spatial.distance
<https://docs.scipy.org/doc/scipy/reference/spatial.distance.html>`_ and
the metrics listed in
:class:`~sklearn.metrics.pairwise.distance_metrics` for valid metric
values.

- 'euclidean':
Pairwise Euclidean distances between points in the dataset.
If metric is "precomputed", X is assumed to be a distance matrix and
must be square during fit.

- 'precomputed':
Pre-computed dissimilarities are passed directly to ``fit`` and
``fit_transform``.
If metric is a callable function, it takes two arrays representing 1D
vectors as inputs and must return one value indicating the distance
between those vectors. This works for Scipy's metrics, but is less
efficient than passing the metric name as a string.

.. versionchanged:: 1.8
All metrics supported by `sklearn.metrics.pairwise.distance_metrics`
are now allowed, not only Euclidean.

normalized_stress : bool or "auto", default="auto"
Whether to return normalized stress value (Stress-1) instead of raw
@@ -515,12 +525,7 @@ class MDS(BaseEstimator):
0.1 fair, and 0.2 poor [1]_.

dissimilarity_matrix_ : ndarray of shape (n_samples, n_samples)
Pairwise dissimilarities between the points. Symmetric matrix that:

- either uses a custom dissimilarity matrix by setting `dissimilarity`
to 'precomputed';
- or constructs a dissimilarity matrix from data using
Euclidean distances.
Pairwise dissimilarities between the points.

n_features_in_ : int
Number of features seen during :term:`fit`.
@@ -586,7 +591,7 @@ class MDS(BaseEstimator):
"eps": [Interval(Real, 0.0, None, closed="left")],
"n_jobs": [None, Integral],
"random_state": ["random_state"],
"dissimilarity": [StrOptions({"euclidean", "precomputed"})],
"dissimilarity": [str, callable],
"normalized_stress": ["boolean", StrOptions({"auto"})],
}

@@ -693,8 +698,13 @@ def fit_transform(self, X, y=None, init=None):

if self.dissimilarity == "precomputed":
self.dissimilarity_matrix_ = X
elif self.dissimilarity == "euclidean":
self.dissimilarity_matrix_ = euclidean_distances(X)
self.dissimilarity_matrix_ = check_symmetric(
self.dissimilarity_matrix_, raise_exception=True
)
else:
self.dissimilarity_matrix_ = pairwise_distances(
X, metric=self.dissimilarity
)

self.embedding_, self.stress_, self.n_iter_ = smacof(
self.dissimilarity_matrix_,
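With the widened `dissimilarity` parameter, `MDS` now accepts any metric
understood by `pairwise_distances`. A minimal sketch (the "manhattan" metric
and the fixed random state are arbitrary choices, not taken from the diff):

from sklearn.datasets import load_iris
from sklearn.manifold import MDS

X, _ = load_iris(return_X_y=True)

# Any metric name accepted by sklearn.metrics.pairwise_distances works here;
# previously only "euclidean" and "precomputed" were allowed.
mds = MDS(n_components=2, dissimilarity="manhattan", random_state=0)
Z = mds.fit_transform(X)
print(Z.shape)                            # (150, 2)
print(mds.dissimilarity_matrix_.shape)    # (150, 150)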
51 changes: 51 additions & 0 deletions sklearn/manifold/tests/test_classical_mds.py
@@ -0,0 +1,51 @@
import numpy as np
import pytest
from numpy.testing import assert_array_almost_equal

from sklearn.datasets import load_iris
from sklearn.decomposition import PCA
from sklearn.manifold import ClassicalMDS
from sklearn.metrics import euclidean_distances


def test_classical_mds_equivalent_to_pca():
X, _ = load_iris(return_X_y=True)

cmds = ClassicalMDS(n_components=2, dissimilarity="euclidean")
pca = PCA(n_components=2)

Z1 = cmds.fit_transform(X)
Z2 = pca.fit_transform(X)

# Swap the signs if necessary
for comp in range(2):
if Z1[0, comp] < 0 and Z2[0, comp] > 0:
Z2[:, comp] *= -1

assert_array_almost_equal(Z1, Z2)

assert_array_almost_equal(np.sqrt(cmds.eigenvalues_), pca.singular_values_)


def test_classical_mds_equivalent_on_data_and_distances():
X, _ = load_iris(return_X_y=True)

cmds = ClassicalMDS(n_components=2, dissimilarity="euclidean")
Z1 = cmds.fit_transform(X)

cmds = ClassicalMDS(n_components=2, dissimilarity="precomputed")
Z2 = cmds.fit_transform(euclidean_distances(X))

assert_array_almost_equal(Z1, Z2)


def test_classical_mds_wrong_inputs():
# Non-symmetric input
dissim = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8]])
with pytest.raises(ValueError, match="Array must be symmetric"):
ClassicalMDS(dissimilarity="precomputed").fit(dissim)

# Non-square input
dissim = np.array([[0, 1, 2], [3, 4, 5]])
with pytest.raises(ValueError, match="array must be 2-dimensional and square"):
ClassicalMDS(dissimilarity="precomputed").fit(dissim)
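One further check follows directly from the documented `eigenvalues_`
attribute (eigenvalues equal the squared 2-norms of the embedding columns). A
possible test along these lines, reusing the imports at the top of this file,
is sketched below; it is a suggestion rather than part of the diff.

def test_classical_mds_eigenvalues_match_embedding_norms():
    # eigenvalues_ is documented to equal the squared column norms of embedding_.
    X, _ = load_iris(return_X_y=True)
    cmds = ClassicalMDS(n_components=2, dissimilarity="euclidean")
    Z = cmds.fit_transform(X)
    assert_array_almost_equal(cmds.eigenvalues_, (Z**2).sum(axis=0))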