Skip to content

TST use global_dtype in sklearn/manifold/tests/test_isomap.py #22673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Nov 29, 2022
210 changes: 117 additions & 93 deletions sklearn/manifold/tests/test_isomap.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,88 @@
from itertools import product
import numpy as np
import math
from numpy.testing import (
assert_almost_equal,
assert_array_almost_equal,
assert_array_equal,
)
import pytest

from sklearn import datasets
from sklearn import datasets, clone
from sklearn import manifold
from sklearn import neighbors
from sklearn import pipeline
from sklearn import preprocessing
from sklearn.datasets import make_blobs
from sklearn.metrics.pairwise import pairwise_distances
from sklearn.utils._testing import assert_allclose, assert_allclose_dense_sparse

from sklearn.utils._testing import (
assert_allclose,
assert_allclose_dense_sparse,
assert_array_equal,
)
from scipy.sparse import rand as sparse_rand

eigen_solvers = ["auto", "dense", "arpack"]
path_methods = ["auto", "FW", "D"]


def create_sample_data(n_pts=25, add_noise=False):
def create_sample_data(dtype, n_pts=25, add_noise=False):
# grid of equidistant points in 2D, n_components = n_dim
n_per_side = int(math.sqrt(n_pts))
X = np.array(list(product(range(n_per_side), repeat=2)))
X = np.array(list(product(range(n_per_side), repeat=2))).astype(dtype, copy=False)
if add_noise:
# add noise in a third dimension
rng = np.random.RandomState(0)
noise = 0.1 * rng.randn(n_pts, 1)
noise = 0.1 * rng.randn(n_pts, 1).astype(dtype, copy=False)
X = np.concatenate((X, noise), 1)
return X


@pytest.mark.parametrize("n_neighbors, radius", [(24, None), (None, np.inf)])
def test_isomap_simple_grid(n_neighbors, radius):
@pytest.mark.parametrize("eigen_solver", eigen_solvers)
@pytest.mark.parametrize("path_method", path_methods)
def test_isomap_simple_grid(
global_dtype, n_neighbors, radius, eigen_solver, path_method
):
# Isomap should preserve distances when all neighbors are used
n_pts = 25
X = create_sample_data(n_pts=n_pts, add_noise=False)
X = create_sample_data(global_dtype, n_pts=n_pts, add_noise=False)

# distances from each point to all others
if n_neighbors is not None:
G = neighbors.kneighbors_graph(X, n_neighbors, mode="distance")
else:
G = neighbors.radius_neighbors_graph(X, radius, mode="distance")

for eigen_solver in eigen_solvers:
for path_method in path_methods:
clf = manifold.Isomap(
n_neighbors=n_neighbors,
radius=radius,
n_components=2,
eigen_solver=eigen_solver,
path_method=path_method,
)
clf.fit(X)

if n_neighbors is not None:
G_iso = neighbors.kneighbors_graph(
clf.embedding_, n_neighbors, mode="distance"
)
else:
G_iso = neighbors.radius_neighbors_graph(
clf.embedding_, radius, mode="distance"
)
assert_allclose_dense_sparse(G, G_iso)
clf = manifold.Isomap(
n_neighbors=n_neighbors,
radius=radius,
n_components=2,
eigen_solver=eigen_solver,
path_method=path_method,
)
clf.fit(X)

if n_neighbors is not None:
G_iso = neighbors.kneighbors_graph(clf.embedding_, n_neighbors, mode="distance")
else:
G_iso = neighbors.radius_neighbors_graph(
clf.embedding_, radius, mode="distance"
)
atol = 1e-5 if global_dtype == np.float32 else 0
assert_allclose_dense_sparse(G, G_iso, atol=atol)


@pytest.mark.parametrize("n_neighbors, radius", [(24, None), (None, np.inf)])
def test_isomap_reconstruction_error(n_neighbors, radius):
@pytest.mark.parametrize("eigen_solver", eigen_solvers)
@pytest.mark.parametrize("path_method", path_methods)
def test_isomap_reconstruction_error(
global_dtype, n_neighbors, radius, eigen_solver, path_method
):
if global_dtype is np.float32:
pytest.skip(
"Skipping test due to numerical instabilities on float32 data"
"from KernelCenterer used in the reconstruction_error method"
)

# Same setup as in test_isomap_simple_grid, with an added dimension
n_pts = 25
X = create_sample_data(n_pts=n_pts, add_noise=True)
X = create_sample_data(global_dtype, n_pts=n_pts, add_noise=True)

# compute input kernel
if n_neighbors is not None:
Expand All @@ -83,43 +92,42 @@ def test_isomap_reconstruction_error(n_neighbors, radius):
centerer = preprocessing.KernelCenterer()
K = centerer.fit_transform(-0.5 * G**2)

for eigen_solver in eigen_solvers:
for path_method in path_methods:
clf = manifold.Isomap(
n_neighbors=n_neighbors,
radius=radius,
n_components=2,
eigen_solver=eigen_solver,
path_method=path_method,
)
clf.fit(X)

# compute output kernel
if n_neighbors is not None:
G_iso = neighbors.kneighbors_graph(
clf.embedding_, n_neighbors, mode="distance"
)
else:
G_iso = neighbors.radius_neighbors_graph(
clf.embedding_, radius, mode="distance"
)
G_iso = G_iso.toarray()
K_iso = centerer.fit_transform(-0.5 * G_iso**2)

# make sure error agrees
reconstruction_error = np.linalg.norm(K - K_iso) / n_pts
assert_almost_equal(reconstruction_error, clf.reconstruction_error())
clf = manifold.Isomap(
n_neighbors=n_neighbors,
radius=radius,
n_components=2,
eigen_solver=eigen_solver,
path_method=path_method,
)
clf.fit(X)

# compute output kernel
if n_neighbors is not None:
G_iso = neighbors.kneighbors_graph(clf.embedding_, n_neighbors, mode="distance")
else:
G_iso = neighbors.radius_neighbors_graph(
clf.embedding_, radius, mode="distance"
)
G_iso = G_iso.toarray()
K_iso = centerer.fit_transform(-0.5 * G_iso**2)

# make sure error agrees
reconstruction_error = np.linalg.norm(K - K_iso) / n_pts
atol = 1e-5 if global_dtype == np.float32 else 0
assert_allclose(reconstruction_error, clf.reconstruction_error(), atol=atol)


@pytest.mark.parametrize("n_neighbors, radius", [(2, None), (None, 0.5)])
def test_transform(n_neighbors, radius):
def test_transform(global_dtype, n_neighbors, radius):
n_samples = 200
n_components = 10
noise_scale = 0.01

# Create S-curve dataset
X, y = datasets.make_s_curve(n_samples, random_state=0)

X = X.astype(global_dtype, copy=False)

# Compute isomap embedding
iso = manifold.Isomap(
n_components=n_components, n_neighbors=n_neighbors, radius=radius
Expand All @@ -136,11 +144,12 @@ def test_transform(n_neighbors, radius):


@pytest.mark.parametrize("n_neighbors, radius", [(2, None), (None, 10.0)])
def test_pipeline(n_neighbors, radius):
def test_pipeline(n_neighbors, radius, global_dtype):
# check that Isomap works fine as a transformer in a Pipeline
# only checks that no error is raised.
# TODO check that it actually does something useful
X, y = datasets.make_blobs(random_state=0)
X = X.astype(global_dtype, copy=False)
clf = pipeline.Pipeline(
[
("isomap", manifold.Isomap(n_neighbors=n_neighbors, radius=radius)),
Expand All @@ -151,7 +160,7 @@ def test_pipeline(n_neighbors, radius):
assert 0.9 < clf.score(X, y)


def test_pipeline_with_nearest_neighbors_transformer():
def test_pipeline_with_nearest_neighbors_transformer(global_dtype):
# Test chaining NearestNeighborsTransformer and Isomap with
# neighbors_algorithm='precomputed'
algorithm = "auto"
Expand All @@ -160,6 +169,9 @@ def test_pipeline_with_nearest_neighbors_transformer():
X, _ = datasets.make_blobs(random_state=0)
X2, _ = datasets.make_blobs(random_state=1)

X = X.astype(global_dtype, copy=False)
X2 = X2.astype(global_dtype, copy=False)

# compare the chained version and the compact version
est_chain = pipeline.make_pipeline(
neighbors.KNeighborsTransformer(
Expand All @@ -173,38 +185,37 @@ def test_pipeline_with_nearest_neighbors_transformer():

Xt_chain = est_chain.fit_transform(X)
Xt_compact = est_compact.fit_transform(X)
assert_array_almost_equal(Xt_chain, Xt_compact)
assert_allclose(Xt_chain, Xt_compact)

Xt_chain = est_chain.transform(X2)
Xt_compact = est_compact.transform(X2)
assert_array_almost_equal(Xt_chain, Xt_compact)
assert_allclose(Xt_chain, Xt_compact)


def test_different_metric():
# Test that the metric parameters work correctly, and default to euclidean
def custom_metric(x1, x2):
return np.sqrt(np.sum(x1**2 + x2**2))

# metric, p, is_euclidean
metrics = [
@pytest.mark.parametrize(
"metric, p, is_euclidean",
[
("euclidean", 2, True),
("manhattan", 1, False),
("minkowski", 1, False),
("minkowski", 2, True),
(custom_metric, 2, False),
]

(lambda x1, x2: np.sqrt(np.sum(x1**2 + x2**2)), 2, False),
],
)
def test_different_metric(global_dtype, metric, p, is_euclidean):
# Isomap must work on various metric parameters work correctly
# and must default to euclidean.
X, _ = datasets.make_blobs(random_state=0)
reference = manifold.Isomap().fit_transform(X)
X = X.astype(global_dtype, copy=False)

for metric, p, is_euclidean in metrics:
embedding = manifold.Isomap(metric=metric, p=p).fit_transform(X)
reference = manifold.Isomap().fit_transform(X)
embedding = manifold.Isomap(metric=metric, p=p).fit_transform(X)

if is_euclidean:
assert_array_almost_equal(embedding, reference)
else:
with pytest.raises(AssertionError, match="not almost equal"):
assert_array_almost_equal(embedding, reference)
if is_euclidean:
assert_allclose(embedding, reference)
else:
with pytest.raises(AssertionError, match="Not equal to tolerance"):
assert_allclose(embedding, reference)


def test_isomap_clone_bug():
Expand All @@ -218,26 +229,38 @@ def test_isomap_clone_bug():

@pytest.mark.parametrize("eigen_solver", eigen_solvers)
@pytest.mark.parametrize("path_method", path_methods)
def test_sparse_input(eigen_solver, path_method):
def test_sparse_input(global_dtype, eigen_solver, path_method, global_random_seed):
# TODO: compare results on dense and sparse data as proposed in:
# https://github.com/scikit-learn/scikit-learn/pull/23585#discussion_r968388186
X = sparse_rand(100, 3, density=0.1, format="csr")
X = sparse_rand(
100,
3,
density=0.1,
format="csr",
dtype=global_dtype,
random_state=global_random_seed,
)

clf = manifold.Isomap(
iso_dense = manifold.Isomap(
n_components=2,
eigen_solver=eigen_solver,
path_method=path_method,
n_neighbors=8,
)
clf.fit(X)
clf.transform(X)
iso_sparse = clone(iso_dense)

X_trans_dense = iso_dense.fit_transform(X.toarray())
X_trans_sparse = iso_sparse.fit_transform(X)

assert_allclose(X_trans_sparse, X_trans_dense, rtol=1e-4, atol=1e-4)

def test_isomap_fit_precomputed_radius_graph():

def test_isomap_fit_precomputed_radius_graph(global_dtype):
# Isomap.fit_transform must yield similar result when using
# a precomputed distance matrix.

X, y = datasets.make_s_curve(200, random_state=0)
X = X.astype(global_dtype, copy=False)
radius = 10

g = neighbors.radius_neighbors_graph(X, radius=radius, mode="distance")
Expand All @@ -247,7 +270,8 @@ def test_isomap_fit_precomputed_radius_graph():

isomap = manifold.Isomap(n_neighbors=None, radius=radius, metric="minkowski")
result = isomap.fit_transform(X)
assert_allclose(precomputed_result, result)
atol = 1e-5 if global_dtype == np.float32 else 0
assert_allclose(precomputed_result, result, atol=atol)


def test_isomap_fitted_attributes_dtype(global_dtype):
Expand Down Expand Up @@ -294,10 +318,10 @@ def test_multiple_connected_components():
manifold.Isomap(n_neighbors=2).fit(X)


def test_multiple_connected_components_metric_precomputed():
def test_multiple_connected_components_metric_precomputed(global_dtype):
# Test that an error is raised when the graph has multiple components
# and when X is a precomputed neighbors graph.
X = np.array([0, 1, 2, 5, 6, 7])[:, None]
X = np.array([0, 1, 2, 5, 6, 7])[:, None].astype(global_dtype, copy=False)

# works with a precomputed distance matrix (dense)
X_distances = pairwise_distances(X)
Expand Down