@@ -33,54 +33,20 @@ def train_wrap(
33
33
cdef model * model
34
34
cdef char_const_ptr error_msg
35
35
cdef int len_w
36
-
37
- # The implementation for float32 and float64 uses a single interface.
38
- # This is done by accepting the data as a pointer to a buffer of bytes.
39
- # In this regard, we define a pointer to pass the address of the first
40
- # element of the buffer seen as raw bytes (hence the use of `char *`).
41
- #
42
- # We proceed in two steps using intermediate memory views to have Cython
43
- # have sufficient typing information not to use PyObjects.
44
- cdef cnp.float64_t[::1 ] X_data_64
45
- cdef cnp.float32_t[::1 ] X_data_32
46
- cdef char * X_data_as_bytes_ptr = NULL
47
-
48
- # The same is done for `indices` and `indptr` in the CSR case.
49
- cdef cnp.int32_t[::1 ] X_indices
50
- cdef char * X_indices_as_bytes_ptr = NULL
51
-
52
- cdef cnp.int32_t[::1 ] X_indptr
53
- cdef char * X_indptr_as_bytes_ptr = NULL
54
-
55
- cdef bint X_stores_float64_data = X.dtype == np.float64
36
+ cdef char * x_data_bytes_ptr
37
+ cdef cnp.int32_t[::1 ] x_indices
38
+ cdef cnp.int32_t[::1 ] x_indptr
39
+ cdef bint x_has_type_float64 = X.dtype == np.float64
56
40
57
41
if is_sparse:
58
- # X is a CSR matrix here, a format which stores the values
59
- # as a contiguous buffer via a NumPy array in a `data` attribute.
60
- # We get the address of the first element of the buffer which
61
- # we reference using a pointer to bytes.
62
- if X_stores_float64_data:
63
- X_data_64 = X.data
64
- X_data_as_bytes_ptr = < char * > & X_data_64[0 ]
65
- else :
66
- X_data_32 = X.data
67
- X_data_as_bytes_ptr = < char * > & X_data_32[0 ]
68
-
69
- # Similar operations are to be performed for `indices` and `indptr`.
70
- X_indices = X.indices
71
- X_indices_as_bytes_ptr = < char * > & X_indices[0 ]
72
-
73
- X_indptr = X.indptr
74
- X_indptr_as_bytes_ptr = < char * > & X_indptr[0 ]
75
-
42
+ x_data_bytes_ptr = _get_sparse_x_data_bytes(x = X, x_has_type_float64 = x_has_type_float64)
43
+ x_indices = X.indices
44
+ x_indptr = X.indptr
76
45
problem = csr_set_problem(
77
- # Underneath, the data will be statically re-interpreted as
78
- # either float32 or float64 depending on the boolean passed as
79
- # the second argument.
80
- X_data_as_bytes_ptr,
81
- X_stores_float64_data,
82
- X_indices_as_bytes_ptr,
83
- X_indptr_as_bytes_ptr,
46
+ x_data_bytes_ptr,
47
+ x_has_type_float64,
48
+ < char * > & x_indices[0 ],
49
+ < char * > & x_indptr[0 ],
84
50
(< cnp.int32_t> X.shape[0 ]),
85
51
(< cnp.int32_t> X.shape[1 ]),
86
52
(< cnp.int32_t> X.nnz),
@@ -89,17 +55,9 @@ def train_wrap(
89
55
< char * > & Y[0 ]
90
56
)
91
57
else :
92
- # X simply is a 2D NumPy array in this case.
93
- # This is reshapeable to a 1D NumPy array in O(1) (only strides are changed).
94
- if X_stores_float64_data:
95
- X_data_64 = X.reshape(- 1 )
96
- X_data_as_bytes_ptr = < char * > & X_data_64[0 ]
97
- else :
98
- X_data_32 = X.reshape(- 1 )
99
- X_data_as_bytes_ptr = < char * > & X_data_32[0 ]
100
-
58
+ x_data_bytes_ptr = _get_x_data_bytes(x = X, x_has_type_float64 = x_has_type_float64)
101
59
problem = set_problem(
102
- X_data_as_bytes_ptr ,
60
+ x_data_bytes_ptr ,
103
61
X.dtype == np.float64,
104
62
(< cnp.int32_t> X.shape[0 ]),
105
63
(< cnp.int32_t> X.shape[1 ]),
@@ -115,8 +73,8 @@ def train_wrap(
115
73
eps,
116
74
C,
117
75
class_weight.shape[0 ],
118
- < char * > & class_weight_label[0 ] if class_weight_label.size > 0 else NULL ,
119
- < char * > & class_weight[0 ] if class_weight.size > 0 else NULL ,
76
+ < char * > & class_weight_label[0 ] if class_weight_label.size > 0 else NULL ,
77
+ < char * > & class_weight[0 ] if class_weight.size > 0 else NULL ,
120
78
max_iter,
121
79
random_seed,
122
80
epsilon
@@ -168,6 +126,37 @@ def train_wrap(
168
126
return w.base, n_iter.base
169
127
170
128
129
cdef char * _get_sparse_x_data_bytes(object x, bint x_has_type_float64) except NULL:
    """Return the address of the first element of ``x.data`` as a ``char *``.

    The float32 and float64 code paths share a single interface by handing
    the values over as a pointer to raw bytes; the caller passes
    ``x_has_type_float64`` alongside the pointer so the bytes can be
    re-interpreted with the right element type.

    Parameters
    ----------
    x : object
        Object exposing a ``data`` attribute that supports the buffer
        protocol as a contiguous 1D array of float64 or float32 values
        (presumably a SciPy CSR matrix -- confirm against the caller).
    x_has_type_float64 : bint
        True when ``x.data`` holds float64 values, False for float32.

    Returns
    -------
    char *
        Address of ``x.data[0]``. The pointer is borrowed: the memoryviews
        used here are released on return, so it is only valid while the
        caller keeps the array behind ``x.data`` alive.

    Notes
    -----
    ``except NULL`` lets errors raised while acquiring the memoryview
    (wrong dtype, non-contiguous buffer) or while indexing element 0
    propagate to the caller instead of being swallowed. A successful call
    never returns NULL because it returns the address of an existing element.
    """
    # Typed memoryviews give Cython enough typing information to take the
    # address of an element without going through Python objects.
    cdef cnp.float64_t[::1] x_data_64
    cdef cnp.float32_t[::1] x_data_32
    # Initialised so that no code path can ever return an indeterminate value.
    cdef char * x_data_bytes_ptr = NULL

    if x_has_type_float64:
        x_data_64 = x.data
        x_data_bytes_ptr = <char *> &x_data_64[0]
    else:
        x_data_32 = x.data
        x_data_bytes_ptr = <char *> &x_data_32[0]

    return x_data_bytes_ptr
142
+
143
+
144
cdef char * _get_x_data_bytes(object x, bint x_has_type_float64) except NULL:
    """Return the address of the first element of a dense array as a ``char *``.

    The float32 and float64 code paths share a single interface by handing
    the values over as a pointer to raw bytes; the caller passes
    ``x_has_type_float64`` alongside the pointer so the bytes can be
    re-interpreted with the right element type.

    Parameters
    ----------
    x : object
        C-contiguous NumPy array of float64 or float32 values.
    x_has_type_float64 : bint
        True when ``x`` holds float64 values, False for float32.

    Returns
    -------
    char *
        Address of the first element of ``x``'s buffer. The pointer is
        borrowed: it is only valid while the caller keeps ``x`` alive.

    Raises
    ------
    ValueError
        If ``x`` is not C-contiguous.

    Notes
    -----
    ``except NULL`` lets Python exceptions raised here propagate to the
    caller. A successful call never returns NULL because it returns the
    address of an existing element.
    """
    cdef cnp.float64_t[::1] x_data_64
    cdef cnp.float32_t[::1] x_data_32
    # Initialised so that no code path can ever return an indeterminate value.
    cdef char * x_data_bytes_ptr = NULL

    # `reshape(-1)` returns a view of a C-contiguous array but silently
    # copies otherwise. Such a copy would only be referenced by locals of
    # this function and freed on return, leaving the caller with a dangling
    # pointer, so non-contiguous input is rejected up front.
    if not x.flags.c_contiguous:
        raise ValueError(
            "x must be C-contiguous: its buffer is passed by address and "
            "cannot be copied."
        )

    # O(1) for a C-contiguous array: only shape and strides change.
    x_as_1d_array = x.reshape(-1)
    if x_has_type_float64:
        x_data_64 = x_as_1d_array
        x_data_bytes_ptr = <char *> &x_data_64[0]
    else:
        x_data_32 = x_as_1d_array
        x_data_bytes_ptr = <char *> &x_data_32[0]

    return x_data_bytes_ptr
158
+
159
+
171
160
def set_verbosity_wrap (int verbosity ):
172
161
"""
173
162
Control verbosity of libsvm library
0 commit comments