Skip to content

CLN Cleaned TreeUnionFind in _hdbscan/_tree.pyx #25827

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions sklearn/cluster/_hdbscan/_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -282,40 +282,37 @@ cdef max_lambdas(cnp.ndarray hierarchy):
return deaths


cdef class TreeUnionFind (object):
@cython.final
cdef class TreeUnionFind:

cdef cnp.ndarray _data_arr
cdef cnp.intp_t[:, ::1] _data
cdef cnp.ndarray is_component
cdef cnp.intp_t[:, ::1] data
cdef cnp.uint8_t[::1] is_component

def __init__(self, size):
self._data_arr = np.zeros((size, 2), dtype=np.intp)
self._data_arr.T[0] = np.arange(size)
self._data = self._data_arr
self.is_component = np.ones(size, dtype=bool)
cdef cnp.intp_t idx
self.data = np.zeros((size, 2), dtype=np.intp)
for idx in range(size):
self.data[idx, 0] = idx
self.is_component = np.ones(size, dtype=np.uint8)

cdef union_(self, cnp.intp_t x, cnp.intp_t y):
cdef void union(self, cnp.intp_t x, cnp.intp_t y):
cdef cnp.intp_t x_root = self.find(x)
cdef cnp.intp_t y_root = self.find(y)

if self._data[x_root, 1] < self._data[y_root, 1]:
self._data[x_root, 0] = y_root
elif self._data[x_root, 1] > self._data[y_root, 1]:
self._data[y_root, 0] = x_root
if self.data[x_root, 1] < self.data[y_root, 1]:
self.data[x_root, 0] = y_root
elif self.data[x_root, 1] > self.data[y_root, 1]:
self.data[y_root, 0] = x_root
else:
self._data[y_root, 0] = x_root
self._data[x_root, 1] += 1

self.data[y_root, 0] = x_root
self.data[x_root, 1] += 1
return

cdef find(self, cnp.intp_t x):
if self._data[x, 0] != x:
self._data[x, 0] = self.find(self._data[x, 0])
cdef cnp.intp_t find(self, cnp.intp_t x):
if self.data[x, 0] != x:
self.data[x, 0] = self.find(self.data[x, 0])
self.is_component[x] = False
return self._data[x, 0]

cdef cnp.ndarray[cnp.intp_t, ndim=1] components(self):
return self.is_component.nonzero()[0]
return self.data[x, 0]


cpdef cnp.ndarray[cnp.intp_t, ndim=1] labelling_at_cut(
Expand Down Expand Up @@ -361,8 +358,8 @@ cpdef cnp.ndarray[cnp.intp_t, ndim=1] labelling_at_cut(
cluster = n_samples
for row in linkage:
if row[2] < cut:
union_find.union_(<cnp.intp_t> row[0], cluster)
union_find.union_(<cnp.intp_t> row[1], cluster)
union_find.union(<cnp.intp_t> row[0], cluster)
union_find.union(<cnp.intp_t> row[1], cluster)
cluster += 1

cluster_size = np.zeros(cluster, dtype=np.intp)
Expand Down Expand Up @@ -416,7 +413,7 @@ cdef cnp.ndarray[cnp.intp_t, ndim=1] do_labelling(
child = child_array[n]
parent = parent_array[n]
if child not in clusters:
union_find.union_(parent, child)
union_find.union(parent, child)

for n in range(root_cluster):
cluster = union_find.find(n)
Expand Down