Skip to content

Commit 243e013

Browse files
committed
Use fused types in kmeans_elkan
1 parent 9567ef5 commit 243e013

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

sklearn/cluster/_k_means_elkan.pyx

+28-22
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111
cimport numpy as np
1212
cimport cython
13+
from cython cimport floating
1314

1415
from libc.math cimport sqrt
1516

@@ -18,8 +19,8 @@ from ._k_means import _centers_dense
1819
from ..utils.fixes import partition
1920

2021

21-
cdef double euclidian_dist(double* a, double* b, int n_features) nogil:
22-
cdef double result, tmp
22+
cdef floating euclidian_dist(floating* a, floating* b, int n_features) nogil:
23+
cdef floating result, tmp
2324
result = 0
2425
cdef int i
2526
for i in range(n_features):
@@ -29,8 +30,8 @@ cdef double euclidian_dist(double* a, double* b, int n_features) nogil:
2930

3031

3132
cdef update_labels_distances_inplace(
32-
double* X, double* centers, double[:, :] center_half_distances,
33-
int[:] labels, double[:, :] lower_bounds, double[:] upper_bounds,
33+
floating* X, floating* centers, floating[:, :] center_half_distances,
34+
int[:] labels, floating[:, :] lower_bounds, floating[:] upper_bounds,
3435
int n_samples, int n_features, int n_clusters):
3536
"""
3637
Calculate upper and lower bounds for each sample.
@@ -81,9 +82,9 @@ cdef update_labels_distances_inplace(
8182
"""
8283
# assigns closest center to X
8384
# uses triangle inequality
84-
cdef double* x
85-
cdef double* c
86-
cdef double d_c, dist
85+
cdef floating* x
86+
cdef floating* c
87+
cdef floating d_c, dist
8788
cdef int c_x, j, sample
8889
for sample in range(n_samples):
8990
# assign first cluster center
@@ -103,8 +104,8 @@ cdef update_labels_distances_inplace(
103104
upper_bounds[sample] = d_c
104105

105106

106-
def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters,
107-
np.ndarray[np.float64_t, ndim=2, mode='c'] init,
107+
def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
108+
np.ndarray[floating, ndim=2, mode='c'] init,
108109
float tol=1e-4, int max_iter=30, verbose=False):
109110
"""Run Elkan's k-means.
110111
@@ -128,30 +129,35 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters,
128129
Whether to be verbose.
129130
130131
"""
131-
#initialize
132-
cdef np.ndarray[np.float64_t, ndim=2, mode='c'] centers_ = init
133-
cdef double* centers_p = <double*>centers_.data
134-
cdef double* X_p = <double*>X_.data
135-
cdef double* x_p
132+
if floating is float:
133+
dtype = np.float32
134+
else:
135+
dtype = np.float64
136+
137+
#initialize
138+
cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init
139+
cdef floating* centers_p = <floating*>centers_.data
140+
cdef floating* X_p = <floating*>X_.data
141+
cdef floating* x_p
136142
cdef Py_ssize_t n_samples = X_.shape[0]
137143
cdef Py_ssize_t n_features = X_.shape[1]
138144
cdef int point_index, center_index, label
139-
cdef float upper_bound, distance
140-
cdef double[:, :] center_half_distances = euclidean_distances(centers_) / 2.
141-
cdef double[:, :] lower_bounds = np.zeros((n_samples, n_clusters))
142-
cdef double[:] distance_next_center
145+
cdef floating upper_bound, distance
146+
cdef floating[:, :] center_half_distances = euclidean_distances(centers_) / 2.
147+
cdef floating[:, :] lower_bounds = np.zeros((n_samples, n_clusters), dtype=dtype)
148+
cdef floating[:] distance_next_center
143149
labels_ = np.empty(n_samples, dtype=np.int32)
144150
cdef int[:] labels = labels_
145-
upper_bounds_ = np.empty(n_samples, dtype=np.float)
146-
cdef double[:] upper_bounds = upper_bounds_
151+
upper_bounds_ = np.empty(n_samples, dtype=dtype)
152+
cdef floating[:] upper_bounds = upper_bounds_
147153

148154
# Get the inital set of upper bounds and lower bounds for each sample.
149155
update_labels_distances_inplace(X_p, centers_p, center_half_distances,
150156
labels, lower_bounds, upper_bounds,
151157
n_samples, n_features, n_clusters)
152158
cdef np.uint8_t[:] bounds_tight = np.ones(n_samples, dtype=np.uint8)
153159
cdef np.uint8_t[:] points_to_update = np.zeros(n_samples, dtype=np.uint8)
154-
cdef np.ndarray[np.float64_t, ndim=2, mode='c'] new_centers
160+
cdef np.ndarray[floating, ndim=2, mode='c'] new_centers
155161

156162
if max_iter <= 0:
157163
raise ValueError('Number of iterations should be a positive number'
@@ -226,7 +232,7 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters,
226232

227233
# reassign centers
228234
centers_ = new_centers
229-
centers_p = <double*>new_centers.data
235+
centers_p = <floating*>new_centers.data
230236

231237
# update between-center distances
232238
center_half_distances = euclidean_distances(centers_) / 2.

0 commit comments

Comments
 (0)