Skip to content

[MRG+1] Allows KMeans/MiniBatchKMeans to use float32 internally by using cython fused types #6846

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 17 commits into from
Jun 21, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,11 @@ Enhancements
(`#6697 <https://github.com/scikit-learn/scikit-learn/pull/6697>`_) by
`Raghav R V`_.

- :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now work
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to rebase and manage a conflict here.

with ``np.float32`` and ``np.float64`` input data without converting it.
This makes it possible to reduce memory consumption by using ``np.float32``.
(`#6846 <https://github.com/scikit-learn/scikit-learn/pull/6846>`_)
By `Sebastian Säger`_ and `YenChen Lin`_.

Bug fixes
.........
Expand Down Expand Up @@ -1769,7 +1774,7 @@ List of contributors for release 0.15 by number of commits.
* 4 Alexis Metaireau
* 4 Ignacio Rossi
* 4 Virgile Fritsch
* 4 Sebastian Saeger
* 4 Sebastian Säger
* 4 Ilambharathi Kanniah
* 4 sdenton4
* 4 Robert Layton
Expand Down Expand Up @@ -4266,4 +4271,8 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.

.. _Wenhua Yang: https://github.com/geekoala

.. _Arnaud Fouchet: https://github.com/afouchet
.. _Arnaud Fouchet: https://github.com/afouchet

.. _Sebastian Säger: https://github.com/ssaeger

.. _YenChen Lin: https://github.com/yenchenlin
102 changes: 71 additions & 31 deletions sklearn/cluster/_k_means.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import numpy as np
import scipy.sparse as sp
cimport numpy as np
cimport cython
from cython cimport floating

from ..utils.extmath import norm
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
Expand All @@ -23,18 +24,19 @@ ctypedef np.int32_t INT

cdef extern from "cblas.h":
double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)

np.import_array()


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[DOUBLE, ndim=2] centers,
cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
np.ndarray[floating, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] labels,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""Compute label assignment and inertia for a dense array

Return the inertia (sum of squared distances to the centers).
Expand All @@ -43,33 +45,52 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
unsigned int n_clusters = centers.shape[0]
unsigned int n_features = centers.shape[1]
unsigned int n_samples = X.shape[0]
unsigned int x_stride = X.strides[1] / sizeof(DOUBLE)
unsigned int center_stride = centers.strides[1] / sizeof(DOUBLE)
unsigned int x_stride
unsigned int center_stride
unsigned int sample_idx, center_idx, feature_idx
unsigned int store_distances = 0
unsigned int k
np.ndarray[floating, ndim=1] center_squared_norms
# the following variables are always double because making them floating
# does not save any memory, but makes the code much bigger
DOUBLE inertia = 0.0
DOUBLE min_dist
DOUBLE dist
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
n_clusters, dtype=np.float64)

if floating is float:
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
x_stride = X.strides[1] / sizeof(float)
center_stride = centers.strides[1] / sizeof(float)
else:
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
x_stride = X.strides[1] / sizeof(DOUBLE)
center_stride = centers.strides[1] / sizeof(DOUBLE)

if n_samples == distances.shape[0]:
store_distances = 1

for center_idx in range(n_clusters):
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], center_stride,
&centers[center_idx, 0], center_stride)
if floating is float:
center_squared_norms[center_idx] = sdot(
n_features, &centers[center_idx, 0], center_stride,
&centers[center_idx, 0], center_stride)
else:
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], center_stride,
&centers[center_idx, 0], center_stride)

for sample_idx in range(n_samples):
min_dist = -1
for center_idx in range(n_clusters):
dist = 0.0
# hardcoded: minimize euclidean distance to cluster center:
# ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
&centers[center_idx, 0], center_stride)
if floating is float:
dist += sdot(n_features, &X[sample_idx, 0], x_stride,
&centers[center_idx, 0], center_stride)
else:
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
&centers[center_idx, 0], center_stride)
dist *= -2
dist += center_squared_norms[center_idx]
dist += x_squared_norms[sample_idx]
Expand All @@ -88,15 +109,15 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
@cython.wraparound(False)
@cython.cdivision(True)
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[DOUBLE, ndim=2] centers,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] labels,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""Compute label assignment and inertia for a CSR input

Return the inertia (sum of squared distances to the centers).
"""
cdef:
np.ndarray[DOUBLE, ndim=1] X_data = X.data
np.ndarray[floating, ndim=1] X_data = X.data
np.ndarray[INT, ndim=1] X_indices = X.indices
np.ndarray[INT, ndim=1] X_indptr = X.indptr
unsigned int n_clusters = centers.shape[0]
Expand All @@ -105,18 +126,28 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
unsigned int store_distances = 0
unsigned int sample_idx, center_idx, feature_idx
unsigned int k
np.ndarray[floating, ndim=1] center_squared_norms
# the following variables are always double because making them floating
# does not save any memory, but makes the code much bigger
DOUBLE inertia = 0.0
DOUBLE min_dist
DOUBLE dist
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
n_clusters, dtype=np.float64)

if floating is float:
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
else:
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)

if n_samples == distances.shape[0]:
store_distances = 1

for center_idx in range(n_clusters):
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
if floating is float:
center_squared_norms[center_idx] = sdot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
else:
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)

for sample_idx in range(n_samples):
min_dist = -1
Expand All @@ -143,17 +174,17 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
@cython.wraparound(False)
@cython.cdivision(True)
def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[DOUBLE, ndim=2] centers,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] counts,
np.ndarray[INT, ndim=1] nearest_center,
np.ndarray[DOUBLE, ndim=1] old_center,
np.ndarray[floating, ndim=1] old_center,
int compute_squared_diff):
"""Incremental update of the centers for sparse MiniBatchKMeans.

Parameters
----------

X: CSR matrix, dtype float64
X: CSR matrix, dtype float
The complete (pre allocated) training set as a CSR matrix.

centers: array, shape (n_clusters, n_features)
Expand All @@ -179,7 +210,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
of the algorithm.
"""
cdef:
np.ndarray[DOUBLE, ndim=1] X_data = X.data
np.ndarray[floating, ndim=1] X_data = X.data
np.ndarray[int, ndim=1] X_indices = X.indices
np.ndarray[int, ndim=1] X_indptr = X.indptr
unsigned int n_samples = X.shape[0]
Expand Down Expand Up @@ -245,9 +276,9 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
def _centers_dense(np.ndarray[floating, ndim=2] X,
np.ndarray[INT, ndim=1] labels, int n_clusters,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""M step of the K-means EM algorithm

Computation of cluster centers / means.
Expand Down Expand Up @@ -275,7 +306,12 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
n_samples = X.shape[0]
n_features = X.shape[1]
cdef int i, j, c
cdef np.ndarray[DOUBLE, ndim=2] centers = np.zeros((n_clusters, n_features))
cdef np.ndarray[floating, ndim=2] centers
if floating is float:
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
else:
centers = np.zeros((n_clusters, n_features), dtype=np.float64)

n_samples_in_cluster = bincount(labels, minlength=n_clusters)
empty_clusters = np.where(n_samples_in_cluster == 0)[0]
# maybe also relocate small clusters?
Expand Down Expand Up @@ -303,7 +339,7 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
@cython.wraparound(False)
@cython.cdivision(True)
def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""M step of the K-means EM algorithm

Computation of cluster centers / means.
Expand All @@ -329,19 +365,23 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
cdef int n_features = X.shape[1]
cdef int curr_label

cdef np.ndarray[DOUBLE, ndim=1] data = X.data
cdef np.ndarray[floating, ndim=1] data = X.data
cdef np.ndarray[int, ndim=1] indices = X.indices
cdef np.ndarray[int, ndim=1] indptr = X.indptr

cdef np.ndarray[DOUBLE, ndim=2, mode="c"] centers = \
np.zeros((n_clusters, n_features))
cdef np.ndarray[floating, ndim=2, mode="c"] centers
cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
bincount(labels, minlength=n_clusters)
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
np.where(n_samples_in_cluster == 0)[0]
cdef int n_empty_clusters = empty_clusters.shape[0]

if floating is float:
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
else:
centers = np.zeros((n_clusters, n_features), dtype=np.float64)

# maybe also relocate small clusters?

if n_empty_clusters > 0:
Expand Down
50 changes: 28 additions & 22 deletions sklearn/cluster/_k_means_elkan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import numpy as np
cimport numpy as np
cimport cython
from cython cimport floating

from libc.math cimport sqrt

Expand All @@ -18,8 +19,8 @@ from ._k_means import _centers_dense
from ..utils.fixes import partition


cdef double euclidian_dist(double* a, double* b, int n_features) nogil:
cdef double result, tmp
cdef floating euclidian_dist(floating* a, floating* b, int n_features) nogil:
cdef floating result, tmp
result = 0
cdef int i
for i in range(n_features):
Expand All @@ -29,8 +30,8 @@ cdef double euclidian_dist(double* a, double* b, int n_features) nogil:


cdef update_labels_distances_inplace(
double* X, double* centers, double[:, :] center_half_distances,
int[:] labels, double[:, :] lower_bounds, double[:] upper_bounds,
floating* X, floating* centers, floating[:, :] center_half_distances,
int[:] labels, floating[:, :] lower_bounds, floating[:] upper_bounds,
int n_samples, int n_features, int n_clusters):
"""
Calculate upper and lower bounds for each sample.
Expand Down Expand Up @@ -81,9 +82,9 @@ cdef update_labels_distances_inplace(
"""
# assigns closest center to X
# uses triangle inequality
cdef double* x
cdef double* c
cdef double d_c, dist
cdef floating* x
cdef floating* c
cdef floating d_c, dist
cdef int c_x, j, sample
for sample in range(n_samples):
# assign first cluster center
Expand All @@ -103,8 +104,8 @@ cdef update_labels_distances_inplace(
upper_bounds[sample] = d_c


def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters,
np.ndarray[np.float64_t, ndim=2, mode='c'] init,
def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
np.ndarray[floating, ndim=2, mode='c'] init,
float tol=1e-4, int max_iter=30, verbose=False):
"""Run Elkan's k-means.

Expand All @@ -128,30 +129,35 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters,
Whether to be verbose.

"""
#initialize
cdef np.ndarray[np.float64_t, ndim=2, mode='c'] centers_ = init
cdef double* centers_p = <double*>centers_.data
cdef double* X_p = <double*>X_.data
cdef double* x_p
if floating is float:
dtype = np.float32
else:
dtype = np.float64

#initialize
cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init
cdef floating* centers_p = <floating*>centers_.data
cdef floating* X_p = <floating*>X_.data
cdef floating* x_p
cdef Py_ssize_t n_samples = X_.shape[0]
cdef Py_ssize_t n_features = X_.shape[1]
cdef int point_index, center_index, label
cdef float upper_bound, distance
cdef double[:, :] center_half_distances = euclidean_distances(centers_) / 2.
cdef double[:, :] lower_bounds = np.zeros((n_samples, n_clusters))
cdef double[:] distance_next_center
cdef floating upper_bound, distance
cdef floating[:, :] center_half_distances = euclidean_distances(centers_) / 2.
cdef floating[:, :] lower_bounds = np.zeros((n_samples, n_clusters), dtype=dtype)
cdef floating[:] distance_next_center
labels_ = np.empty(n_samples, dtype=np.int32)
cdef int[:] labels = labels_
upper_bounds_ = np.empty(n_samples, dtype=np.float)
cdef double[:] upper_bounds = upper_bounds_
upper_bounds_ = np.empty(n_samples, dtype=dtype)
cdef floating[:] upper_bounds = upper_bounds_

# Get the initial set of upper bounds and lower bounds for each sample.
update_labels_distances_inplace(X_p, centers_p, center_half_distances,
labels, lower_bounds, upper_bounds,
n_samples, n_features, n_clusters)
cdef np.uint8_t[:] bounds_tight = np.ones(n_samples, dtype=np.uint8)
cdef np.uint8_t[:] points_to_update = np.zeros(n_samples, dtype=np.uint8)
cdef np.ndarray[np.float64_t, ndim=2, mode='c'] new_centers
cdef np.ndarray[floating, ndim=2, mode='c'] new_centers

if max_iter <= 0:
raise ValueError('Number of iterations should be a positive number'
Expand Down Expand Up @@ -226,7 +232,7 @@ def k_means_elkan(np.ndarray[np.float64_t, ndim=2, mode='c'] X_, int n_clusters,

# reassign centers
centers_ = new_centers
centers_p = <double*>new_centers.data
centers_p = <floating*>new_centers.data

# update between-center distances
center_half_distances = euclidean_distances(centers_) / 2.
Expand Down
Loading