Skip to content

Commit 21e1642

Browse files
authored
Replaced unsigned char with uint8_t in Cython code (#29597)
1 parent 45cf8ec commit 21e1642

File tree

9 files changed

+62
-58
lines changed

9 files changed

+62
-58
lines changed

sklearn/ensemble/_hist_gradient_boosting/_binning.pyx

+6-5
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ from cython.parallel import prange
44
from libc.math cimport isnan
55

66
from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C
7+
from ...utils._typedefs cimport uint8_t
78

89

910
def _map_to_bins(const X_DTYPE_C [:, :] data,
1011
list binning_thresholds,
11-
const unsigned char[::1] is_categorical,
12-
const unsigned char missing_values_bin_idx,
12+
const uint8_t[::1] is_categorical,
13+
const uint8_t missing_values_bin_idx,
1314
int n_threads,
1415
X_BINNED_DTYPE_C [::1, :] binned):
1516
"""Bin continuous and categorical values to discrete integer-coded levels.
@@ -24,7 +25,7 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
2425
binning_thresholds : list of arrays
2526
For each feature, stores the increasing numeric values that are
2627
used to separate the bins.
27-
is_categorical : ndarray of unsigned char of shape (n_features,)
28+
is_categorical : ndarray of uint8_t of shape (n_features,)
2829
Indicates categorical features.
2930
n_threads : int
3031
Number of OpenMP threads to use.
@@ -48,8 +49,8 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
4849
cdef void _map_col_to_bins(
4950
const X_DTYPE_C [:] data,
5051
const X_DTYPE_C [:] binning_thresholds,
51-
const unsigned char is_categorical,
52-
const unsigned char missing_values_bin_idx,
52+
const uint8_t is_categorical,
53+
const uint8_t missing_values_bin_idx,
5354
int n_threads,
5455
X_BINNED_DTYPE_C [:] binned
5556
):

sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd

+7-5
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,19 @@ from .common cimport X_BINNED_DTYPE_C
22
from .common cimport BITSET_DTYPE_C
33
from .common cimport BITSET_INNER_DTYPE_C
44
from .common cimport X_DTYPE_C
5+
from ...utils._typedefs cimport uint8_t
6+
57

68
cdef void init_bitset(BITSET_DTYPE_C bitset) noexcept nogil
79

810
cdef void set_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) noexcept nogil
911

10-
cdef unsigned char in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) noexcept nogil
12+
cdef uint8_t in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) noexcept nogil
1113

12-
cpdef unsigned char in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
13-
X_BINNED_DTYPE_C val) noexcept nogil
14+
cpdef uint8_t in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
15+
X_BINNED_DTYPE_C val) noexcept nogil
1416

15-
cdef unsigned char in_bitset_2d_memoryview(
16-
const BITSET_INNER_DTYPE_C [:, :] bitset,
17+
cdef uint8_t in_bitset_2d_memoryview(
18+
const BITSET_INNER_DTYPE_C[:, :] bitset,
1719
X_BINNED_DTYPE_C val,
1820
unsigned int row) noexcept nogil

sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx

+8-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ from .common cimport BITSET_INNER_DTYPE_C
22
from .common cimport BITSET_DTYPE_C
33
from .common cimport X_DTYPE_C
44
from .common cimport X_BINNED_DTYPE_C
5+
from ...utils._typedefs cimport uint8_t
56

67

78
# A bitset is a data structure used to represent sets of integers in [0, n]. We
@@ -25,20 +26,19 @@ cdef inline void set_bitset(BITSET_DTYPE_C bitset, # OUT
2526
bitset[val // 32] |= (1 << (val % 32))
2627

2728

28-
cdef inline unsigned char in_bitset(BITSET_DTYPE_C bitset,
29-
X_BINNED_DTYPE_C val) noexcept nogil:
30-
29+
cdef inline uint8_t in_bitset(BITSET_DTYPE_C bitset,
30+
X_BINNED_DTYPE_C val) noexcept nogil:
3131
return (bitset[val // 32] >> (val % 32)) & 1
3232

3333

34-
cpdef inline unsigned char in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
35-
X_BINNED_DTYPE_C val) noexcept nogil:
34+
cpdef inline uint8_t in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
35+
X_BINNED_DTYPE_C val) noexcept nogil:
3636
return (bitset[val // 32] >> (val % 32)) & 1
3737

38-
cdef inline unsigned char in_bitset_2d_memoryview(const BITSET_INNER_DTYPE_C [:, :] bitset,
39-
X_BINNED_DTYPE_C val,
40-
unsigned int row) noexcept nogil:
4138

39+
cdef inline uint8_t in_bitset_2d_memoryview(const BITSET_INNER_DTYPE_C[:, :] bitset,
40+
X_BINNED_DTYPE_C val,
41+
unsigned int row) noexcept nogil:
4242
# Same as above but works on 2d memory views to avoid the creation of 1d
4343
# memory views. See https://github.com/scikit-learn/scikit-learn/issues/17299
4444
return (bitset[row, val // 32] >> (val % 32)) & 1

sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx

+3-3
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from cython.parallel import prange
44
from libc.math cimport isnan
55
import numpy as np
66

7-
from ...utils._typedefs cimport intp_t
7+
from ...utils._typedefs cimport intp_t, uint8_t
88
from .common cimport X_DTYPE_C
99
from .common cimport Y_DTYPE_C
1010
from .common import Y_DTYPE
@@ -89,7 +89,7 @@ def _predict_from_binned_data(
8989
node_struct [:] nodes,
9090
const X_BINNED_DTYPE_C [:, :] binned_data,
9191
BITSET_INNER_DTYPE_C [:, :] binned_left_cat_bitsets,
92-
const unsigned char missing_values_bin_idx,
92+
const uint8_t missing_values_bin_idx,
9393
int n_threads,
9494
Y_DTYPE_C [:] out):
9595

@@ -109,7 +109,7 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data(
109109
const X_BINNED_DTYPE_C [:, :] binned_data,
110110
const BITSET_INNER_DTYPE_C [:, :] binned_left_cat_bitsets,
111111
const int row,
112-
const unsigned char missing_values_bin_idx) noexcept nogil:
112+
const uint8_t missing_values_bin_idx) noexcept nogil:
113113
# Need to pass the whole array and the row index, else prange won't work.
114114
# See issue Cython #2798
115115

sklearn/ensemble/_hist_gradient_boosting/common.pxd

+3-3
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ cdef packed struct node_struct:
2424
unsigned int count
2525
intp_t feature_idx
2626
X_DTYPE_C num_threshold
27-
unsigned char missing_go_to_left
27+
uint8_t missing_go_to_left
2828
unsigned int left
2929
unsigned int right
3030
Y_DTYPE_C gain
3131
unsigned int depth
32-
unsigned char is_leaf
32+
uint8_t is_leaf
3333
X_BINNED_DTYPE_C bin_threshold
34-
unsigned char is_categorical
34+
uint8_t is_categorical
3535
# The index of the corresponding bitsets in the Predictor's bitset arrays.
3636
# Only used if is_categorical is True
3737
unsigned int bitset_idx

sklearn/ensemble/_hist_gradient_boosting/histogram.pyx

+5-4
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ from .common import HISTOGRAM_DTYPE
1212
from .common cimport hist_struct
1313
from .common cimport X_BINNED_DTYPE_C
1414
from .common cimport G_H_DTYPE_C
15+
from ...utils._typedefs cimport uint8_t
1516

1617

1718
# Notes:
@@ -79,13 +80,13 @@ cdef class HistogramBuilder:
7980
G_H_DTYPE_C [::1] hessians
8081
G_H_DTYPE_C [::1] ordered_gradients
8182
G_H_DTYPE_C [::1] ordered_hessians
82-
unsigned char hessians_are_constant
83+
uint8_t hessians_are_constant
8384
int n_threads
8485

8586
def __init__(self, const X_BINNED_DTYPE_C [::1, :] X_binned,
8687
unsigned int n_bins, G_H_DTYPE_C [::1] gradients,
8788
G_H_DTYPE_C [::1] hessians,
88-
unsigned char hessians_are_constant,
89+
uint8_t hessians_are_constant,
8990
int n_threads):
9091

9192
self.X_binned = X_binned
@@ -130,7 +131,7 @@ cdef class HistogramBuilder:
130131
int f_idx
131132
int i
132133
# need local views to avoid python interactions
133-
unsigned char hessians_are_constant = self.hessians_are_constant
134+
uint8_t hessians_are_constant = self.hessians_are_constant
134135
int n_allowed_features = self.n_features
135136
G_H_DTYPE_C [::1] ordered_gradients = self.ordered_gradients
136137
G_H_DTYPE_C [::1] gradients = self.gradients
@@ -195,7 +196,7 @@ cdef class HistogramBuilder:
195196
self.ordered_gradients[:n_samples]
196197
G_H_DTYPE_C [::1] ordered_hessians = \
197198
self.ordered_hessians[:n_samples]
198-
unsigned char hessians_are_constant = \
199+
uint8_t hessians_are_constant = \
199200
self.hessians_are_constant
200201

201202
# Set histograms to zero.

sklearn/ensemble/_hist_gradient_boosting/splitting.pyx

+25-25
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ cdef struct split_info_struct:
3232
Y_DTYPE_C gain
3333
int feature_idx
3434
unsigned int bin_idx
35-
unsigned char missing_go_to_left
35+
uint8_t missing_go_to_left
3636
Y_DTYPE_C sum_gradient_left
3737
Y_DTYPE_C sum_gradient_right
3838
Y_DTYPE_C sum_hessian_left
@@ -41,7 +41,7 @@ cdef struct split_info_struct:
4141
unsigned int n_samples_right
4242
Y_DTYPE_C value_left
4343
Y_DTYPE_C value_right
44-
unsigned char is_categorical
44+
uint8_t is_categorical
4545
BITSET_DTYPE_C left_cat_bitset
4646

4747

@@ -168,11 +168,11 @@ cdef class Splitter:
168168
const X_BINNED_DTYPE_C [::1, :] X_binned
169169
unsigned int n_features
170170
const unsigned int [::1] n_bins_non_missing
171-
unsigned char missing_values_bin_idx
172-
const unsigned char [::1] has_missing_values
173-
const unsigned char [::1] is_categorical
171+
uint8_t missing_values_bin_idx
172+
const uint8_t [::1] has_missing_values
173+
const uint8_t [::1] is_categorical
174174
const signed char [::1] monotonic_cst
175-
unsigned char hessians_are_constant
175+
uint8_t hessians_are_constant
176176
Y_DTYPE_C l2_regularization
177177
Y_DTYPE_C min_hessian_to_split
178178
unsigned int min_samples_leaf
@@ -188,15 +188,15 @@ cdef class Splitter:
188188
def __init__(self,
189189
const X_BINNED_DTYPE_C [::1, :] X_binned,
190190
const unsigned int [::1] n_bins_non_missing,
191-
const unsigned char missing_values_bin_idx,
192-
const unsigned char [::1] has_missing_values,
193-
const unsigned char [::1] is_categorical,
191+
const uint8_t missing_values_bin_idx,
192+
const uint8_t [::1] has_missing_values,
193+
const uint8_t [::1] is_categorical,
194194
const signed char [::1] monotonic_cst,
195195
Y_DTYPE_C l2_regularization,
196196
Y_DTYPE_C min_hessian_to_split=1e-3,
197197
unsigned int min_samples_leaf=20,
198198
Y_DTYPE_C min_gain_to_split=0.,
199-
unsigned char hessians_are_constant=False,
199+
uint8_t hessians_are_constant=False,
200200
Y_DTYPE_C feature_fraction_per_split=1.0,
201201
rng=np.random.RandomState(),
202202
unsigned int n_threads=1):
@@ -307,14 +307,14 @@ cdef class Splitter:
307307
cdef:
308308
int n_samples = sample_indices.shape[0]
309309
X_BINNED_DTYPE_C bin_idx = split_info.bin_idx
310-
unsigned char missing_go_to_left = split_info.missing_go_to_left
311-
unsigned char missing_values_bin_idx = self.missing_values_bin_idx
310+
uint8_t missing_go_to_left = split_info.missing_go_to_left
311+
uint8_t missing_values_bin_idx = self.missing_values_bin_idx
312312
int feature_idx = split_info.feature_idx
313313
const X_BINNED_DTYPE_C [::1] X_binned = \
314314
self.X_binned[:, feature_idx]
315315
unsigned int [::1] left_indices_buffer = self.left_indices_buffer
316316
unsigned int [::1] right_indices_buffer = self.right_indices_buffer
317-
unsigned char is_categorical = split_info.is_categorical
317+
uint8_t is_categorical = split_info.is_categorical
318318
# Cython is unhappy if we set left_cat_bitset to
319319
# split_info.left_cat_bitset directly, so we need a tmp var
320320
BITSET_INNER_DTYPE_C [:] cat_bitset_tmp = split_info.left_cat_bitset
@@ -334,7 +334,7 @@ cdef class Splitter:
334334
int thread_idx
335335
int sample_idx
336336
int right_child_position
337-
unsigned char turn_left
337+
uint8_t turn_left
338338
int [:] left_offset = np.zeros(n_threads, dtype=np.int32)
339339
int [:] right_offset = np.zeros(n_threads, dtype=np.int32)
340340

@@ -482,8 +482,8 @@ cdef class Splitter:
482482
int n_allowed_features
483483
split_info_struct split_info
484484
split_info_struct * split_infos
485-
const unsigned char [::1] has_missing_values = self.has_missing_values
486-
const unsigned char [::1] is_categorical = self.is_categorical
485+
const uint8_t [::1] has_missing_values = self.has_missing_values
486+
const uint8_t [::1] is_categorical = self.is_categorical
487487
const signed char [::1] monotonic_cst = self.monotonic_cst
488488
int n_threads = self.n_threads
489489
bint has_interaction_cst = False
@@ -622,7 +622,7 @@ cdef class Splitter:
622622
cdef void _find_best_bin_to_split_left_to_right(
623623
Splitter self,
624624
unsigned int feature_idx,
625-
unsigned char has_missing_values,
625+
uint8_t has_missing_values,
626626
const hist_struct [:, ::1] histograms, # IN
627627
unsigned int n_samples,
628628
Y_DTYPE_C sum_gradients,
@@ -658,7 +658,7 @@ cdef class Splitter:
658658
Y_DTYPE_C sum_gradient_right
659659
Y_DTYPE_C loss_current_node
660660
Y_DTYPE_C gain
661-
unsigned char found_better_split = False
661+
uint8_t found_better_split = False
662662

663663
Y_DTYPE_C best_sum_hessian_left
664664
Y_DTYPE_C best_sum_gradient_left
@@ -771,7 +771,7 @@ cdef class Splitter:
771771
Y_DTYPE_C loss_current_node
772772
Y_DTYPE_C gain
773773
unsigned int start = self.n_bins_non_missing[feature_idx] - 2
774-
unsigned char found_better_split = False
774+
uint8_t found_better_split = False
775775

776776
Y_DTYPE_C best_sum_hessian_left
777777
Y_DTYPE_C best_sum_gradient_left
@@ -851,7 +851,7 @@ cdef class Splitter:
851851
cdef void _find_best_bin_to_split_category(
852852
self,
853853
unsigned int feature_idx,
854-
unsigned char has_missing_values,
854+
uint8_t has_missing_values,
855855
const hist_struct [:, ::1] histograms, # IN
856856
unsigned int n_samples,
857857
Y_DTYPE_C sum_gradients,
@@ -890,7 +890,7 @@ cdef class Splitter:
890890
unsigned int n_samples_left, n_samples_right
891891
Y_DTYPE_C gain
892892
Y_DTYPE_C best_gain = -1.0
893-
unsigned char found_better_split = False
893+
uint8_t found_better_split = False
894894
Y_DTYPE_C best_sum_hessian_left
895895
Y_DTYPE_C best_sum_gradient_left
896896
unsigned int best_n_samples_left
@@ -1139,12 +1139,12 @@ cdef inline Y_DTYPE_C _loss_from_value(
11391139
"""
11401140
return sum_gradient * value
11411141

1142-
cdef inline unsigned char sample_goes_left(
1143-
unsigned char missing_go_to_left,
1144-
unsigned char missing_values_bin_idx,
1142+
cdef inline uint8_t sample_goes_left(
1143+
uint8_t missing_go_to_left,
1144+
uint8_t missing_values_bin_idx,
11451145
X_BINNED_DTYPE_C split_bin_idx,
11461146
X_BINNED_DTYPE_C bin_value,
1147-
unsigned char is_categorical,
1147+
uint8_t is_categorical,
11481148
BITSET_DTYPE_C left_cat_bitset) noexcept nogil:
11491149
"""Helper to decide whether sample should go to left or right child."""
11501150

sklearn/linear_model/_sgd_fast.pyx.tp

+3-3
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ from cython cimport floating
3333
from libc.math cimport exp, fabs, isfinite, log, pow, INFINITY
3434

3535
from .._loss._loss cimport CyLossFunction
36-
from ..utils._typedefs cimport uint32_t
36+
from ..utils._typedefs cimport uint32_t, uint8_t
3737
from ..utils._weight_vector cimport WeightVector32, WeightVector64
3838
from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64
3939

@@ -287,7 +287,7 @@ def _plain_sgd{{name_suffix}}(
287287
double C,
288288
double l1_ratio,
289289
SequentialDataset{{name_suffix}} dataset,
290-
const unsigned char[::1] validation_mask,
290+
const uint8_t[::1] validation_mask,
291291
bint early_stopping,
292292
validation_score_cb,
293293
int n_iter_no_change,
@@ -333,7 +333,7 @@ def _plain_sgd{{name_suffix}}(
333333
l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
334334
dataset : SequentialDataset
335335
A concrete ``SequentialDataset`` object.
336-
validation_mask : ndarray[unsigned char, ndim=1]
336+
validation_mask : ndarray[uint8_t, ndim=1]
337337
Equal to True on the validation set.
338338
early_stopping : boolean
339339
Whether to use a stopping criterion based on the validation set.

sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ from cython cimport floating, final, integral
44
from cython.operator cimport dereference as deref
55
from cython.parallel cimport parallel, prange
66
from ._classmode cimport WeightingStrategy
7-
from ...utils._typedefs cimport intp_t, float64_t
7+
from ...utils._typedefs cimport intp_t, float64_t, uint8_t
88

99
import numpy as np
1010
from scipy.sparse import issparse
@@ -25,7 +25,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix}
2525
intp_t outlier_label_index
2626
bint outlier_label_exists
2727
bint outliers_exist
28-
unsigned char[::1] outliers
28+
uint8_t[::1] outliers
2929
object outlier_label
3030
float64_t[:, ::1] class_scores
3131
WeightingStrategy weight_type

0 commit comments

Comments
 (0)