Skip to content

TST use global_dtype in sklearn/cluster/tests/test_birch.py #22671

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Nov 18, 2022
28 changes: 21 additions & 7 deletions sklearn/cluster/tests/test_birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
from sklearn.utils._testing import assert_allclose


def test_n_samples_leaves_roots(global_random_seed):
def test_n_samples_leaves_roots(global_random_seed, global_dtype):
# Sanity check for the number of samples in leaves and roots
X, y = make_blobs(n_samples=10, random_state=global_random_seed)
X = X.astype(global_dtype, copy=False)
brc = Birch()
brc.fit(X)
n_samples_root = sum([sc.n_samples_ for sc in brc.root_.subclusters_])
Expand All @@ -30,9 +31,10 @@ def test_n_samples_leaves_roots(global_random_seed):
assert n_samples_root == X.shape[0]


def test_partial_fit(global_random_seed):
def test_partial_fit(global_random_seed, global_dtype):
# Test that fit is equivalent to calling partial_fit multiple times
X, y = make_blobs(n_samples=100, random_state=global_random_seed)
X = X.astype(global_dtype, copy=False)
brc = Birch(n_clusters=3)
brc.fit(X)
brc_partial = Birch(n_clusters=None)
Expand All @@ -47,17 +49,22 @@ def test_partial_fit(global_random_seed):
assert_array_equal(brc_partial.subcluster_labels_, brc.subcluster_labels_)


def test_birch_predict(global_random_seed):
def test_birch_predict(global_random_seed, global_dtype):
# Test the predict method predicts the nearest centroid.
rng = np.random.RandomState(global_random_seed)
X = generate_clustered_data(n_clusters=3, n_features=3, n_samples_per_cluster=10)
X = X.astype(global_dtype, copy=False)

# n_samples * n_samples_per_cluster
shuffle_indices = np.arange(30)
rng.shuffle(shuffle_indices)
X_shuffle = X[shuffle_indices, :]
brc = Birch(n_clusters=4, threshold=1.0)
brc.fit(X_shuffle)

# Birch must preserve inputs' dtype
assert brc.subcluster_centers_.dtype == global_dtype

assert_array_equal(brc.labels_, brc.predict(X_shuffle))
centroids = brc.subcluster_centers_
nearest_centroid = brc.subcluster_labels_[
Expand All @@ -66,9 +73,10 @@ def test_birch_predict(global_random_seed):
assert_allclose(v_measure_score(nearest_centroid, brc.labels_), 1.0)


def test_n_clusters(global_random_seed):
def test_n_clusters(global_random_seed, global_dtype):
# Test that n_clusters param works properly
X, y = make_blobs(n_samples=100, centers=10, random_state=global_random_seed)
X = X.astype(global_dtype, copy=False)
brc1 = Birch(n_clusters=10)
brc1.fit(X)
assert len(brc1.subcluster_centers_) > 10
Expand All @@ -88,16 +96,20 @@ def test_n_clusters(global_random_seed):
brc4.fit(X)


def test_sparse_X(global_random_seed):
def test_sparse_X(global_random_seed, global_dtype):
# Test that sparse and dense data give same results
X, y = make_blobs(n_samples=100, centers=10, random_state=global_random_seed)
X = X.astype(global_dtype, copy=False)
brc = Birch(n_clusters=10)
brc.fit(X)

csr = sparse.csr_matrix(X)
brc_sparse = Birch(n_clusters=10)
brc_sparse.fit(csr)

# Birch must preserve inputs' dtype
assert brc_sparse.subcluster_centers_.dtype == global_dtype

assert_array_equal(brc.labels_, brc_sparse.labels_)
assert_allclose(brc.subcluster_centers_, brc_sparse.subcluster_centers_)

Expand All @@ -122,9 +134,10 @@ def check_branching_factor(node, branching_factor):
check_branching_factor(cluster.child_, branching_factor)


def test_branching_factor(global_random_seed):
def test_branching_factor(global_random_seed, global_dtype):
# Test that nodes have at max branching_factor number of subclusters
X, y = make_blobs(random_state=global_random_seed)
X = X.astype(global_dtype, copy=False)
branching_factor = 9

# Purposefully set a low threshold to maximize the subclusters.
Expand All @@ -146,9 +159,10 @@ def check_threshold(birch_instance, threshold):
current_leaf = current_leaf.next_leaf_


def test_threshold(global_random_seed):
def test_threshold(global_random_seed, global_dtype):
# Test that the leaf subclusters have a threshold lesser than radius
X, y = make_blobs(n_samples=80, centers=4, random_state=global_random_seed)
X = X.astype(global_dtype, copy=False)
brc = Birch(threshold=0.5, n_clusters=None)
brc.fit(X)
check_threshold(brc, 0.5)
Expand Down