diff --git a/sklearn/linear_model/sgd_fast.pyx b/sklearn/linear_model/sgd_fast.pyx index ddea4b9710501..1b98e8c907f2b 100644 --- a/sklearn/linear_model/sgd_fast.pyx +++ b/sklearn/linear_model/sgd_fast.pyx @@ -22,7 +22,7 @@ from numpy.math cimport INFINITY cdef extern from "sgd_fast_helpers.h": bint skl_isfinite(double) nogil -from sklearn.utils.weight_vector cimport WeightVector +from sklearn.utils.weight_vector cimport WeightVector64 as WeightVector from sklearn.utils.seq_dataset cimport SequentialDataset64 as SequentialDataset np.import_array() diff --git a/sklearn/utils/setup.py b/sklearn/utils/setup.py index f3002ed3ffed9..8c13ffec21ecf 100644 --- a/sklearn/utils/setup.py +++ b/sklearn/utils/setup.py @@ -48,7 +48,10 @@ def configuration(parent_package='', top_path=None): # generate files from a template pyx_templates = ['sklearn/utils/seq_dataset.pyx.tp', - 'sklearn/utils/seq_dataset.pxd.tp'] + 'sklearn/utils/seq_dataset.pxd.tp', + 'sklearn/utils/weight_vector.pyx.tp', + 'sklearn/utils/weight_vector.pxd.tp', + ] for pyxfiles in pyx_templates: outfile = pyxfiles.replace('.tp', '') diff --git a/sklearn/utils/weight_vector.pxd b/sklearn/utils/weight_vector.pxd deleted file mode 100644 index 4ba4374c05e6c..0000000000000 --- a/sklearn/utils/weight_vector.pxd +++ /dev/null @@ -1,30 +0,0 @@ -# cython: language_level=3 -"""Efficient (dense) parameter vector implementation for linear models. """ - -cimport numpy as np - - -cdef extern from "math.h": - cdef extern double sqrt(double x) - - -cdef class WeightVector(object): - cdef np.ndarray w - cdef np.ndarray aw - cdef double *w_data_ptr - cdef double *aw_data_ptr - cdef double wscale - cdef double average_a - cdef double average_b - cdef int n_features - cdef double sq_norm - - cdef void add(self, double *x_data_ptr, int *x_ind_ptr, - int xnnz, double c) nogil - cdef void add_average(self, double *x_data_ptr, int *x_ind_ptr, - int xnnz, double c, double num_iter) nogil - cdef double dot(self, double *x_data_ptr, int *x_ind_ptr, - int xnnz) nogil - cdef void scale(self, double c) nogil - cdef void reset_wscale(self) nogil - cdef double norm(self) nogil diff --git a/sklearn/utils/weight_vector.pxd.tp b/sklearn/utils/weight_vector.pxd.tp new file mode 100644 index 0000000000000..144f84669afbd --- /dev/null +++ b/sklearn/utils/weight_vector.pxd.tp @@ -0,0 +1,56 @@ +{{py: + +""" +Efficient (dense) parameter vector implementation for linear models. + +Template file for easily generate fused types consistent code using Tempita +(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py). + +Generated file: weight_vector.pxd + +Each class is duplicated for all dtypes (float and double). The keywords +between double braces are substituted in setup.py. +""" + +# name, c_type +dtypes = [('64', 'double'), + ('32', 'float')] + +def get_dispatch(dtypes): + for name, c_type in dtypes: + yield name, c_type + +}} + +# cython: language_level=3 + +cimport numpy as np + +{{for name, c_type in get_dispatch(dtypes)}} + +cdef extern from "math.h": + cdef extern {{c_type}} sqrt({{c_type}} x) + + +cdef class WeightVector{{name}}(object): + cdef np.ndarray w + cdef np.ndarray aw + cdef {{c_type}} *w_data_ptr + cdef {{c_type}} *aw_data_ptr + cdef {{c_type}} wscale + cdef {{c_type}} average_a + cdef {{c_type}} average_b + cdef int n_features + cdef {{c_type}} sq_norm + + cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, + int xnnz, {{c_type}} c) nogil + cdef void add_average(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, + int xnnz, {{c_type}} c, {{c_type}} num_iter) nogil + cdef {{c_type}} dot(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, + int xnnz) nogil + cdef void scale(self, {{c_type}} c) nogil + cdef void reset_wscale(self) nogil + cdef {{c_type}} norm(self) nogil + +{{endfor}} \ No newline at end of file diff --git a/sklearn/utils/weight_vector.pyx b/sklearn/utils/weight_vector.pyx.tp similarity index 65% rename from sklearn/utils/weight_vector.pyx rename to sklearn/utils/weight_vector.pyx.tp index 91c5273d210e4..0dbee605e9a3d 100644 --- a/sklearn/utils/weight_vector.pyx +++ b/sklearn/utils/weight_vector.pyx.tp @@ -1,3 +1,27 @@ +{{py: + +""" +Efficient (dense) parameter vector implementation for linear models. + +Template file for easily generate fused types consistent code using Tempita +(https://github.com/cython/cython/blob/master/Cython/Tempita/_tempita.py). + +Generated file: weight_vector.pxd + +Each class is duplicated for all dtypes (float and double). The keywords +between double braces are substituted in setup.py. +""" + +# name, c_type +dtypes = [('64', 'double'), + ('32', 'float')] + +def get_dispatch(dtypes): + for name, c_type in dtypes: + yield name, c_type + +}} + # cython: cdivision=True # cython: boundscheck=False # cython: wraparound=False @@ -21,7 +45,9 @@ from ._cython_blas cimport _dot, _scal, _axpy np.import_array() -cdef class WeightVector(object): +{{for name, c_type in get_dispatch(dtypes)}} + +cdef class WeightVector{{name}}(object): """Dense vector represented by a scalar and a numpy array. The class provides methods to ``add`` a sparse vector @@ -31,24 +57,24 @@ cdef class WeightVector(object): Attributes ---------- - w : ndarray, dtype=double, order='C' + w : ndarray, dtype={{c_type}}, order='C' The numpy array which backs the weight vector. - aw : ndarray, dtype=double, order='C' + aw : ndarray, dtype={{c_type}}, order='C' The numpy array which backs the average_weight vector. - w_data_ptr : double* + w_data_ptr : {{c_type}}* A pointer to the data of the numpy array. - wscale : double + wscale : {{c_type}} The scale of the vector. n_features : int The number of features (= dimensionality of ``w``). - sq_norm : double + sq_norm : {{c_type}} The squared norm of ``w``. """ def __cinit__(self, - np.ndarray[double, ndim=1, mode='c'] w, - np.ndarray[double, ndim=1, mode='c'] aw): - cdef double *wdata = w.data + np.ndarray[{{c_type}}, ndim=1, mode='c'] w, + np.ndarray[{{c_type}}, ndim=1, mode='c'] aw): + cdef {{c_type}} *wdata = <{{c_type}} *>w.data if w.shape[0] > INT_MAX: raise ValueError("More than %d features not supported; got %d." @@ -57,40 +83,40 @@ cdef class WeightVector(object): self.w_data_ptr = wdata self.wscale = 1.0 self.n_features = w.shape[0] - self.sq_norm = _dot(w.shape[0], wdata, 1, wdata, 1) + self.sq_norm = _dot(w.shape[0], wdata, 1, wdata, 1) self.aw = aw if self.aw is not None: - self.aw_data_ptr = aw.data + self.aw_data_ptr = <{{c_type}} *>aw.data self.average_a = 0.0 self.average_b = 1.0 - cdef void add(self, double *x_data_ptr, int *x_ind_ptr, int xnnz, - double c) nogil: + cdef void add(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz, + {{c_type}} c) nogil: """Scales sample x by constant c and adds it to the weight vector. This operation updates ``sq_norm``. Parameters ---------- - x_data_ptr : double* + x_data_ptr : {{c_type}}* The array which holds the feature values of ``x``. x_ind_ptr : np.intc* The array which holds the feature indices of ``x``. xnnz : int The number of non-zero features of ``x``. - c : double + c : {{c_type}} The scaling constant for the example. """ cdef int j cdef int idx - cdef double val - cdef double innerprod = 0.0 - cdef double xsqnorm = 0.0 + cdef {{c_type}} val + cdef {{c_type}} innerprod = 0.0 + cdef {{c_type}} xsqnorm = 0.0 # the next two lines save a factor of 2! - cdef double wscale = self.wscale - cdef double* w_data_ptr = self.w_data_ptr + cdef {{c_type}} wscale = self.wscale + cdef {{c_type}}* w_data_ptr = self.w_data_ptr for j in range(xnnz): idx = x_ind_ptr[j] @@ -104,30 +130,30 @@ cdef class WeightVector(object): # Update the average weights according to the sparse trick defined # here: https://research.microsoft.com/pubs/192769/tricks-2012.pdf # by Leon Bottou - cdef void add_average(self, double *x_data_ptr, int *x_ind_ptr, int xnnz, - double c, double num_iter) nogil: + cdef void add_average(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz, + {{c_type}} c, {{c_type}} num_iter) nogil: """Updates the average weight vector. Parameters ---------- - x_data_ptr : double* + x_data_ptr : {{c_type}}* The array which holds the feature values of ``x``. x_ind_ptr : np.intc* The array which holds the feature indices of ``x``. xnnz : int The number of non-zero features of ``x``. - c : double + c : {{c_type}} The scaling constant for the example. - num_iter : double + num_iter : {{c_type}} The total number of iterations. """ cdef int j cdef int idx - cdef double val - cdef double mu = 1.0 / num_iter - cdef double average_a = self.average_a - cdef double wscale = self.wscale - cdef double* aw_data_ptr = self.aw_data_ptr + cdef {{c_type}} val + cdef {{c_type}} mu = 1.0 / num_iter + cdef {{c_type}} average_a = self.average_a + cdef {{c_type}} wscale = self.wscale + cdef {{c_type}}* aw_data_ptr = self.aw_data_ptr for j in range(xnnz): idx = x_ind_ptr[j] @@ -140,13 +166,13 @@ cdef class WeightVector(object): self.average_b /= (1.0 - mu) self.average_a += mu * self.average_b * wscale - cdef double dot(self, double *x_data_ptr, int *x_ind_ptr, + cdef {{c_type}} dot(self, {{c_type}} *x_data_ptr, int *x_ind_ptr, int xnnz) nogil: """Computes the dot product of a sample x and the weight vector. Parameters ---------- - x_data_ptr : double* + x_data_ptr : {{c_type}}* The array which holds the feature values of ``x``. x_ind_ptr : np.intc* The array which holds the feature indices of ``x``. @@ -155,20 +181,20 @@ cdef class WeightVector(object): Returns ------- - innerprod : double + innerprod : {{c_type}} The inner product of ``x`` and ``w``. """ cdef int j cdef int idx - cdef double innerprod = 0.0 - cdef double* w_data_ptr = self.w_data_ptr + cdef {{c_type}} innerprod = 0.0 + cdef {{c_type}}* w_data_ptr = self.w_data_ptr for j in range(xnnz): idx = x_ind_ptr[j] innerprod += w_data_ptr[idx] * x_data_ptr[j] innerprod *= self.wscale return innerprod - cdef void scale(self, double c) nogil: + cdef void scale(self, {{c_type}} c) nogil: """Scales the weight vector by a constant ``c``. It updates ``wscale`` and ``sq_norm``. If ``wscale`` gets too @@ -182,15 +208,17 @@ cdef class WeightVector(object): """Scales each coef of ``w`` by ``wscale`` and resets it to 1. """ if self.aw is not None: _axpy(self.aw.shape[0], self.average_a, - self.w.data, 1, self.aw.data, 1) + <{{c_type}} *>self.w.data, 1, <{{c_type}} *>self.aw.data, 1) _scal(self.aw.shape[0], 1.0 / self.average_b, - self.aw.data, 1) + <{{c_type}} *>self.aw.data, 1) self.average_a = 0.0 self.average_b = 1.0 - _scal(self.w.shape[0], self.wscale, self.w.data, 1) + _scal(self.w.shape[0], self.wscale, <{{c_type}} *>self.w.data, 1) self.wscale = 1.0 - cdef double norm(self) nogil: + cdef {{c_type}} norm(self) nogil: """The L2 norm of the weight vector. """ return sqrt(self.sq_norm) + +{{endfor}} \ No newline at end of file