Skip to content

Commit 407f456

Browse files
committed
* Remove the usage of tobytes()
* Use separate memory views for float64 and float32 to handle the possibly dtypes of X * Separate the functionality to get the bytes of X in functions for sparse and normal ndarray
1 parent d8b3ed6 commit 407f456

File tree

1 file changed

+46
-57
lines changed

1 file changed

+46
-57
lines changed

sklearn/svm/_liblinear.pyx

Lines changed: 46 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -33,54 +33,20 @@ def train_wrap(
3333
cdef model *model
3434
cdef char_const_ptr error_msg
3535
cdef int len_w
36-
37-
# The implementation for float32 and float64 uses a single interface.
38-
# This is done by accepting the data as a pointer to a buffer of bytes.
39-
# In this regard, we define a pointer to pass the address of the first
40-
# element of the buffer seen as raw bytes (hence the use of `char *`).
41-
#
42-
# We proceed in two steps using intermediate memory views to have Cython
43-
# have sufficient typing information not to use PyObjects.
44-
cdef cnp.float64_t[::1] X_data_64
45-
cdef cnp.float32_t[::1] X_data_32
46-
cdef char * X_data_as_bytes_ptr = NULL
47-
48-
# The same is done for `indices` and `indptr` in the CSR case.
49-
cdef cnp.int32_t[::1] X_indices
50-
cdef char * X_indices_as_bytes_ptr = NULL
51-
52-
cdef cnp.int32_t[::1] X_indptr
53-
cdef char * X_indptr_as_bytes_ptr = NULL
54-
55-
cdef bint X_stores_float64_data = X.dtype == np.float64
36+
cdef char * x_data_bytes_ptr
37+
cdef cnp.int32_t[::1] x_indices
38+
cdef cnp.int32_t[::1] x_indptr
39+
cdef bint x_has_type_float64 = X.dtype == np.float64
5640

5741
if is_sparse:
58-
# X is a CSR matrix here, a format which stores the values
59-
# as a contiguous buffer via a NumPy array in a `data` attribute.
60-
# We get the address of the first element of the buffer which
61-
# we reference using a pointer to bytes.
62-
if X_stores_float64_data:
63-
X_data_64 = X.data
64-
X_data_as_bytes_ptr = <char*> &X_data_64[0]
65-
else:
66-
X_data_32 = X.data
67-
X_data_as_bytes_ptr = <char*> &X_data_32[0]
68-
69-
# Similar operations are to be performed for `indices` and `indptr`.
70-
X_indices = X.indices
71-
X_indices_as_bytes_ptr = <char *> &X_indices[0]
72-
73-
X_indptr = X.indptr
74-
X_indptr_as_bytes_ptr = <char *> &X_indptr[0]
75-
42+
x_data_bytes_ptr = _get_sparse_x_data_bytes(x=X, x_has_type_float64=x_has_type_float64)
43+
x_indices = X.indices
44+
x_indptr = X.indptr
7645
problem = csr_set_problem(
77-
# Underneath, the data will be statically re-interpreted as
78-
# either float32 or float64 depending on the boolean passed as
79-
# the second argument.
80-
X_data_as_bytes_ptr,
81-
X_stores_float64_data,
82-
X_indices_as_bytes_ptr,
83-
X_indptr_as_bytes_ptr,
46+
x_data_bytes_ptr,
47+
x_has_type_float64,
48+
<char *> &x_indices[0],
49+
<char *> &x_indptr[0],
8450
(<cnp.int32_t>X.shape[0]),
8551
(<cnp.int32_t>X.shape[1]),
8652
(<cnp.int32_t>X.nnz),
@@ -89,17 +55,9 @@ def train_wrap(
8955
<char *> &Y[0]
9056
)
9157
else:
92-
# X simply is a 2D NumPy array in this case.
93-
# This is reshapeable to a 1D NumPy array in O(1) (only strides are changed).
94-
if X_stores_float64_data:
95-
X_data_64 = X.reshape(-1)
96-
X_data_as_bytes_ptr = <char*> &X_data_64[0]
97-
else:
98-
X_data_32 = X.reshape(-1)
99-
X_data_as_bytes_ptr = <char*> &X_data_32[0]
100-
58+
x_data_bytes_ptr = _get_x_data_bytes(x=X, x_has_type_float64=x_has_type_float64)
10159
problem = set_problem(
102-
X_data_as_bytes_ptr,
60+
x_data_bytes_ptr,
10361
X.dtype == np.float64,
10462
(<cnp.int32_t>X.shape[0]),
10563
(<cnp.int32_t>X.shape[1]),
@@ -115,8 +73,8 @@ def train_wrap(
11573
eps,
11674
C,
11775
class_weight.shape[0],
118-
<char*> &class_weight_label[0] if class_weight_label.size > 0 else NULL,
119-
<char*> &class_weight[0] if class_weight.size > 0 else NULL,
76+
<char *> &class_weight_label[0] if class_weight_label.size > 0 else NULL,
77+
<char *> &class_weight[0] if class_weight.size > 0 else NULL,
12078
max_iter,
12179
random_seed,
12280
epsilon
@@ -168,6 +126,37 @@ def train_wrap(
168126
return w.base, n_iter.base
169127

170128

129+
cdef char * _get_sparse_x_data_bytes(object x, bint x_has_type_float64):
130+
cdef cnp.float64_t[::1] x_data_64
131+
cdef cnp.float32_t[::1] x_data_32
132+
cdef char * x_data_bytes_ptr
133+
134+
if x_has_type_float64:
135+
x_data_64 = x.data
136+
x_data_bytes_ptr = <char *> &x_data_64[0]
137+
else:
138+
x_data_32 = x.data
139+
x_data_bytes_ptr = <char *> &x_data_32[0]
140+
141+
return x_data_bytes_ptr
142+
143+
144+
cdef char * _get_x_data_bytes(object x, bint x_has_type_float64):
145+
cdef cnp.float64_t[::1] x_data_64
146+
cdef cnp.float32_t[::1] x_data_32
147+
cdef char * x_data_bytes_ptr
148+
149+
x_as_1d_array = x.reshape(-1)
150+
if x_has_type_float64:
151+
x_data_64 = x_as_1d_array
152+
x_data_bytes_ptr = <char *> &x_data_64[0]
153+
else:
154+
x_data_32 = x_as_1d_array
155+
x_data_bytes_ptr = <char *> &x_data_32[0]
156+
157+
return x_data_bytes_ptr
158+
159+
171160
def set_verbosity_wrap(int verbosity):
172161
"""
173162
Control verbosity of libsvm library

0 commit comments

Comments
 (0)