Skip to content

Commit ecbe2d7

Browse files
jeremiedbbglemaitre
authored andcommitted
FIX attribute error is BIRCH (#23395)
1 parent bb6a0ef commit ecbe2d7

File tree

3 files changed

+75
-0
lines changed

3 files changed

+75
-0
lines changed

doc/whats_new/v1.1.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,23 @@
22

33
.. currentmodule:: sklearn
44

5+
.. _changes_1_1_2:
6+
7+
Version 1.1.2
8+
=============
9+
10+
**In Development**
11+
12+
Changelog
13+
---------
14+
15+
:mod:`sklearn.cluster`
16+
......................
17+
18+
- |Fix| Fixed a bug in :class:`cluster.Birch` that could trigger an error when splitting
19+
a node if there are duplicates in the dataset.
20+
:pr:`23395` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
21+
522
.. _changes_1_1_1:
623

724
Version 1.1.1

sklearn/cluster/_birch.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,11 @@ def _split_node(node, threshold, branching_factor):
8989
node1_dist, node2_dist = dist[(farthest_idx,)]
9090

9191
node1_closer = node1_dist < node2_dist
92+
# make sure node1 is closest to itself even if all distances are equal.
93+
# This can only happen when all node.centroids_ are duplicates leading to all
94+
# distances between centroids being zero.
95+
node1_closer[farthest_idx[0]] = True
96+
9297
for idx, subcluster in enumerate(node.subclusters_):
9398
if node1_closer[idx]:
9499
new_node1.append_subcluster(subcluster)

sklearn/cluster/tests/test_birch.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,56 @@ def test_feature_names_out():
228228

229229
names_out = brc.get_feature_names_out()
230230
assert_array_equal([f"birch{i}" for i in range(n_clusters)], names_out)
231+
232+
233+
def test_transform_match_across_dtypes():
234+
X, _ = make_blobs(n_samples=80, n_features=4, random_state=0)
235+
brc = Birch(n_clusters=4)
236+
Y_64 = brc.fit_transform(X)
237+
Y_32 = brc.fit_transform(X.astype(np.float32))
238+
239+
assert_allclose(Y_64, Y_32, atol=1e-6)
240+
241+
242+
def test_subcluster_dtype(global_dtype):
243+
X = make_blobs(n_samples=80, n_features=4, random_state=0)[0].astype(
244+
global_dtype, copy=False
245+
)
246+
brc = Birch(n_clusters=4)
247+
assert brc.fit(X).subcluster_centers_.dtype == global_dtype
248+
249+
250+
def test_both_subclusters_updated():
251+
"""Check that both subclusters are updated when a node a split, even when there are
252+
duplicated data points. Non-regression test for #23269.
253+
"""
254+
255+
X = np.array(
256+
[
257+
[-2.6192791, -1.5053215],
258+
[-2.9993038, -1.6863596],
259+
[-2.3724914, -1.3438171],
260+
[-2.336792, -1.3417323],
261+
[-2.4089134, -1.3290224],
262+
[-2.3724914, -1.3438171],
263+
[-3.364009, -1.8846745],
264+
[-2.3724914, -1.3438171],
265+
[-2.617677, -1.5003285],
266+
[-2.2960556, -1.3260119],
267+
[-2.3724914, -1.3438171],
268+
[-2.5459878, -1.4533926],
269+
[-2.25979, -1.3003055],
270+
[-2.4089134, -1.3290224],
271+
[-2.3724914, -1.3438171],
272+
[-2.4089134, -1.3290224],
273+
[-2.5459878, -1.4533926],
274+
[-2.3724914, -1.3438171],
275+
[-2.9720619, -1.7058647],
276+
[-2.336792, -1.3417323],
277+
[-2.3724914, -1.3438171],
278+
],
279+
dtype=np.float32,
280+
)
281+
282+
# no error
283+
Birch(branching_factor=5, threshold=1e-5, n_clusters=None).fit(X)

0 commit comments

Comments
 (0)