Skip to content

Replaced unsigned char with uint8_t in Cython Codebase (9 files) #29597

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits on Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions sklearn/ensemble/_hist_gradient_boosting/_binning.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@ from cython.parallel import prange
from libc.math cimport isnan

from .common cimport X_DTYPE_C, X_BINNED_DTYPE_C
from ...utils._typedefs cimport uint8_t


def _map_to_bins(const X_DTYPE_C [:, :] data,
list binning_thresholds,
const unsigned char[::1] is_categorical,
const unsigned char missing_values_bin_idx,
const uint8_t[::1] is_categorical,
const uint8_t missing_values_bin_idx,
int n_threads,
X_BINNED_DTYPE_C [::1, :] binned):
"""Bin continuous and categorical values to discrete integer-coded levels.
Expand All @@ -24,7 +25,7 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
binning_thresholds : list of arrays
For each feature, stores the increasing numeric values that are
used to separate the bins.
is_categorical : ndarray of unsigned char of shape (n_features,)
is_categorical : ndarray of uint8_t of shape (n_features,)
Indicates categorical features.
n_threads : int
Number of OpenMP threads to use.
Expand All @@ -48,8 +49,8 @@ def _map_to_bins(const X_DTYPE_C [:, :] data,
cdef void _map_col_to_bins(
const X_DTYPE_C [:] data,
const X_DTYPE_C [:] binning_thresholds,
const unsigned char is_categorical,
const unsigned char missing_values_bin_idx,
const uint8_t is_categorical,
const uint8_t missing_values_bin_idx,
int n_threads,
X_BINNED_DTYPE_C [:] binned
):
Expand Down
12 changes: 7 additions & 5 deletions sklearn/ensemble/_hist_gradient_boosting/_bitset.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,19 @@ from .common cimport X_BINNED_DTYPE_C
from .common cimport BITSET_DTYPE_C
from .common cimport BITSET_INNER_DTYPE_C
from .common cimport X_DTYPE_C
from ...utils._typedefs cimport uint8_t


cdef void init_bitset(BITSET_DTYPE_C bitset) noexcept nogil

cdef void set_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) noexcept nogil

cdef unsigned char in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) noexcept nogil
cdef uint8_t in_bitset(BITSET_DTYPE_C bitset, X_BINNED_DTYPE_C val) noexcept nogil

cpdef unsigned char in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
X_BINNED_DTYPE_C val) noexcept nogil
cpdef uint8_t in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
X_BINNED_DTYPE_C val) noexcept nogil

cdef unsigned char in_bitset_2d_memoryview(
const BITSET_INNER_DTYPE_C [:, :] bitset,
cdef uint8_t in_bitset_2d_memoryview(
const BITSET_INNER_DTYPE_C[:, :] bitset,
X_BINNED_DTYPE_C val,
unsigned int row) noexcept nogil
16 changes: 8 additions & 8 deletions sklearn/ensemble/_hist_gradient_boosting/_bitset.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ from .common cimport BITSET_INNER_DTYPE_C
from .common cimport BITSET_DTYPE_C
from .common cimport X_DTYPE_C
from .common cimport X_BINNED_DTYPE_C
from ...utils._typedefs cimport uint8_t


# A bitset is a data structure used to represent sets of integers in [0, n]. We
Expand All @@ -25,20 +26,19 @@ cdef inline void set_bitset(BITSET_DTYPE_C bitset, # OUT
bitset[val // 32] |= (1 << (val % 32))


cdef inline unsigned char in_bitset(BITSET_DTYPE_C bitset,
X_BINNED_DTYPE_C val) noexcept nogil:

cdef inline uint8_t in_bitset(BITSET_DTYPE_C bitset,
X_BINNED_DTYPE_C val) noexcept nogil:
return (bitset[val // 32] >> (val % 32)) & 1


cpdef inline unsigned char in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
X_BINNED_DTYPE_C val) noexcept nogil:
cpdef inline uint8_t in_bitset_memoryview(const BITSET_INNER_DTYPE_C[:] bitset,
X_BINNED_DTYPE_C val) noexcept nogil:
return (bitset[val // 32] >> (val % 32)) & 1

cdef inline unsigned char in_bitset_2d_memoryview(const BITSET_INNER_DTYPE_C [:, :] bitset,
X_BINNED_DTYPE_C val,
unsigned int row) noexcept nogil:

cdef inline uint8_t in_bitset_2d_memoryview(const BITSET_INNER_DTYPE_C[:, :] bitset,
X_BINNED_DTYPE_C val,
unsigned int row) noexcept nogil:
# Same as above but works on 2d memory views to avoid the creation of 1d
# memory views. See https://github.com/scikit-learn/scikit-learn/issues/17299
return (bitset[row, val // 32] >> (val % 32)) & 1
Expand Down
6 changes: 3 additions & 3 deletions sklearn/ensemble/_hist_gradient_boosting/_predictor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ from cython.parallel import prange
from libc.math cimport isnan
import numpy as np

from ...utils._typedefs cimport intp_t
from ...utils._typedefs cimport intp_t, uint8_t
from .common cimport X_DTYPE_C
from .common cimport Y_DTYPE_C
from .common import Y_DTYPE
Expand Down Expand Up @@ -89,7 +89,7 @@ def _predict_from_binned_data(
node_struct [:] nodes,
const X_BINNED_DTYPE_C [:, :] binned_data,
BITSET_INNER_DTYPE_C [:, :] binned_left_cat_bitsets,
const unsigned char missing_values_bin_idx,
const uint8_t missing_values_bin_idx,
int n_threads,
Y_DTYPE_C [:] out):

Expand All @@ -109,7 +109,7 @@ cdef inline Y_DTYPE_C _predict_one_from_binned_data(
const X_BINNED_DTYPE_C [:, :] binned_data,
const BITSET_INNER_DTYPE_C [:, :] binned_left_cat_bitsets,
const int row,
const unsigned char missing_values_bin_idx) noexcept nogil:
const uint8_t missing_values_bin_idx) noexcept nogil:
# Need to pass the whole array and the row index, else prange won't work.
# See issue Cython #2798

Expand Down
6 changes: 3 additions & 3 deletions sklearn/ensemble/_hist_gradient_boosting/common.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ cdef packed struct node_struct:
unsigned int count
intp_t feature_idx
X_DTYPE_C num_threshold
unsigned char missing_go_to_left
uint8_t missing_go_to_left
unsigned int left
unsigned int right
Y_DTYPE_C gain
unsigned int depth
unsigned char is_leaf
uint8_t is_leaf
X_BINNED_DTYPE_C bin_threshold
unsigned char is_categorical
uint8_t is_categorical
# The index of the corresponding bitsets in the Predictor's bitset arrays.
# Only used if is_categorical is True
unsigned int bitset_idx
Expand Down
9 changes: 5 additions & 4 deletions sklearn/ensemble/_hist_gradient_boosting/histogram.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ from .common import HISTOGRAM_DTYPE
from .common cimport hist_struct
from .common cimport X_BINNED_DTYPE_C
from .common cimport G_H_DTYPE_C
from ...utils._typedefs cimport uint8_t


# Notes:
Expand Down Expand Up @@ -79,13 +80,13 @@ cdef class HistogramBuilder:
G_H_DTYPE_C [::1] hessians
G_H_DTYPE_C [::1] ordered_gradients
G_H_DTYPE_C [::1] ordered_hessians
unsigned char hessians_are_constant
uint8_t hessians_are_constant
int n_threads

def __init__(self, const X_BINNED_DTYPE_C [::1, :] X_binned,
unsigned int n_bins, G_H_DTYPE_C [::1] gradients,
G_H_DTYPE_C [::1] hessians,
unsigned char hessians_are_constant,
uint8_t hessians_are_constant,
int n_threads):

self.X_binned = X_binned
Expand Down Expand Up @@ -130,7 +131,7 @@ cdef class HistogramBuilder:
int f_idx
int i
# need local views to avoid python interactions
unsigned char hessians_are_constant = self.hessians_are_constant
uint8_t hessians_are_constant = self.hessians_are_constant
int n_allowed_features = self.n_features
G_H_DTYPE_C [::1] ordered_gradients = self.ordered_gradients
G_H_DTYPE_C [::1] gradients = self.gradients
Expand Down Expand Up @@ -195,7 +196,7 @@ cdef class HistogramBuilder:
self.ordered_gradients[:n_samples]
G_H_DTYPE_C [::1] ordered_hessians = \
self.ordered_hessians[:n_samples]
unsigned char hessians_are_constant = \
uint8_t hessians_are_constant = \
self.hessians_are_constant

# Set histograms to zero.
Expand Down
50 changes: 25 additions & 25 deletions sklearn/ensemble/_hist_gradient_boosting/splitting.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ cdef struct split_info_struct:
Y_DTYPE_C gain
int feature_idx
unsigned int bin_idx
unsigned char missing_go_to_left
uint8_t missing_go_to_left
Y_DTYPE_C sum_gradient_left
Y_DTYPE_C sum_gradient_right
Y_DTYPE_C sum_hessian_left
Expand All @@ -41,7 +41,7 @@ cdef struct split_info_struct:
unsigned int n_samples_right
Y_DTYPE_C value_left
Y_DTYPE_C value_right
unsigned char is_categorical
uint8_t is_categorical
BITSET_DTYPE_C left_cat_bitset


Expand Down Expand Up @@ -168,11 +168,11 @@ cdef class Splitter:
const X_BINNED_DTYPE_C [::1, :] X_binned
unsigned int n_features
const unsigned int [::1] n_bins_non_missing
unsigned char missing_values_bin_idx
const unsigned char [::1] has_missing_values
const unsigned char [::1] is_categorical
uint8_t missing_values_bin_idx
const uint8_t [::1] has_missing_values
const uint8_t [::1] is_categorical
const signed char [::1] monotonic_cst
unsigned char hessians_are_constant
uint8_t hessians_are_constant
Y_DTYPE_C l2_regularization
Y_DTYPE_C min_hessian_to_split
unsigned int min_samples_leaf
Expand All @@ -188,15 +188,15 @@ cdef class Splitter:
def __init__(self,
const X_BINNED_DTYPE_C [::1, :] X_binned,
const unsigned int [::1] n_bins_non_missing,
const unsigned char missing_values_bin_idx,
const unsigned char [::1] has_missing_values,
const unsigned char [::1] is_categorical,
const uint8_t missing_values_bin_idx,
const uint8_t [::1] has_missing_values,
const uint8_t [::1] is_categorical,
const signed char [::1] monotonic_cst,
Y_DTYPE_C l2_regularization,
Y_DTYPE_C min_hessian_to_split=1e-3,
unsigned int min_samples_leaf=20,
Y_DTYPE_C min_gain_to_split=0.,
unsigned char hessians_are_constant=False,
uint8_t hessians_are_constant=False,
Y_DTYPE_C feature_fraction_per_split=1.0,
rng=np.random.RandomState(),
unsigned int n_threads=1):
Expand Down Expand Up @@ -307,14 +307,14 @@ cdef class Splitter:
cdef:
int n_samples = sample_indices.shape[0]
X_BINNED_DTYPE_C bin_idx = split_info.bin_idx
unsigned char missing_go_to_left = split_info.missing_go_to_left
unsigned char missing_values_bin_idx = self.missing_values_bin_idx
uint8_t missing_go_to_left = split_info.missing_go_to_left
uint8_t missing_values_bin_idx = self.missing_values_bin_idx
int feature_idx = split_info.feature_idx
const X_BINNED_DTYPE_C [::1] X_binned = \
self.X_binned[:, feature_idx]
unsigned int [::1] left_indices_buffer = self.left_indices_buffer
unsigned int [::1] right_indices_buffer = self.right_indices_buffer
unsigned char is_categorical = split_info.is_categorical
uint8_t is_categorical = split_info.is_categorical
# Cython is unhappy if we set left_cat_bitset to
# split_info.left_cat_bitset directly, so we need a tmp var
BITSET_INNER_DTYPE_C [:] cat_bitset_tmp = split_info.left_cat_bitset
Expand All @@ -334,7 +334,7 @@ cdef class Splitter:
int thread_idx
int sample_idx
int right_child_position
unsigned char turn_left
uint8_t turn_left
int [:] left_offset = np.zeros(n_threads, dtype=np.int32)
int [:] right_offset = np.zeros(n_threads, dtype=np.int32)

Expand Down Expand Up @@ -482,8 +482,8 @@ cdef class Splitter:
int n_allowed_features
split_info_struct split_info
split_info_struct * split_infos
const unsigned char [::1] has_missing_values = self.has_missing_values
const unsigned char [::1] is_categorical = self.is_categorical
const uint8_t [::1] has_missing_values = self.has_missing_values
const uint8_t [::1] is_categorical = self.is_categorical
const signed char [::1] monotonic_cst = self.monotonic_cst
int n_threads = self.n_threads
bint has_interaction_cst = False
Expand Down Expand Up @@ -622,7 +622,7 @@ cdef class Splitter:
cdef void _find_best_bin_to_split_left_to_right(
Splitter self,
unsigned int feature_idx,
unsigned char has_missing_values,
uint8_t has_missing_values,
const hist_struct [:, ::1] histograms, # IN
unsigned int n_samples,
Y_DTYPE_C sum_gradients,
Expand Down Expand Up @@ -658,7 +658,7 @@ cdef class Splitter:
Y_DTYPE_C sum_gradient_right
Y_DTYPE_C loss_current_node
Y_DTYPE_C gain
unsigned char found_better_split = False
uint8_t found_better_split = False

Y_DTYPE_C best_sum_hessian_left
Y_DTYPE_C best_sum_gradient_left
Expand Down Expand Up @@ -771,7 +771,7 @@ cdef class Splitter:
Y_DTYPE_C loss_current_node
Y_DTYPE_C gain
unsigned int start = self.n_bins_non_missing[feature_idx] - 2
unsigned char found_better_split = False
uint8_t found_better_split = False

Y_DTYPE_C best_sum_hessian_left
Y_DTYPE_C best_sum_gradient_left
Expand Down Expand Up @@ -851,7 +851,7 @@ cdef class Splitter:
cdef void _find_best_bin_to_split_category(
self,
unsigned int feature_idx,
unsigned char has_missing_values,
uint8_t has_missing_values,
const hist_struct [:, ::1] histograms, # IN
unsigned int n_samples,
Y_DTYPE_C sum_gradients,
Expand Down Expand Up @@ -890,7 +890,7 @@ cdef class Splitter:
unsigned int n_samples_left, n_samples_right
Y_DTYPE_C gain
Y_DTYPE_C best_gain = -1.0
unsigned char found_better_split = False
uint8_t found_better_split = False
Y_DTYPE_C best_sum_hessian_left
Y_DTYPE_C best_sum_gradient_left
unsigned int best_n_samples_left
Expand Down Expand Up @@ -1139,12 +1139,12 @@ cdef inline Y_DTYPE_C _loss_from_value(
"""
return sum_gradient * value

cdef inline unsigned char sample_goes_left(
unsigned char missing_go_to_left,
unsigned char missing_values_bin_idx,
cdef inline uint8_t sample_goes_left(
uint8_t missing_go_to_left,
uint8_t missing_values_bin_idx,
X_BINNED_DTYPE_C split_bin_idx,
X_BINNED_DTYPE_C bin_value,
unsigned char is_categorical,
uint8_t is_categorical,
BITSET_DTYPE_C left_cat_bitset) noexcept nogil:
"""Helper to decide whether sample should go to left or right child."""

Expand Down
6 changes: 3 additions & 3 deletions sklearn/linear_model/_sgd_fast.pyx.tp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ from cython cimport floating
from libc.math cimport exp, fabs, isfinite, log, pow, INFINITY

from .._loss._loss cimport CyLossFunction
from ..utils._typedefs cimport uint32_t
from ..utils._typedefs cimport uint32_t, uint8_t
from ..utils._weight_vector cimport WeightVector32, WeightVector64
from ..utils._seq_dataset cimport SequentialDataset32, SequentialDataset64

Expand Down Expand Up @@ -287,7 +287,7 @@ def _plain_sgd{{name_suffix}}(
double C,
double l1_ratio,
SequentialDataset{{name_suffix}} dataset,
const unsigned char[::1] validation_mask,
const uint8_t[::1] validation_mask,
bint early_stopping,
validation_score_cb,
int n_iter_no_change,
Expand Down Expand Up @@ -333,7 +333,7 @@ def _plain_sgd{{name_suffix}}(
l1_ratio=0 corresponds to L2 penalty, l1_ratio=1 to L1.
dataset : SequentialDataset
A concrete ``SequentialDataset`` object.
validation_mask : ndarray[unsigned char, ndim=1]
validation_mask : ndarray[uint8_t, ndim=1]
Equal to True on the validation set.
early_stopping : boolean
Whether to use a stopping criterion based on the validation set.
Expand Down
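4 changes: 2 additions & 2 deletions sklearn/metrics/_pairwise_distances_reduction/_radius_neighbors_classmode.pyx.tp
Original file line number Diff line number Diff line change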
Expand Up @@ -4,7 +4,7 @@ from cython cimport floating, final, integral
from cython.operator cimport dereference as deref
from cython.parallel cimport parallel, prange
from ._classmode cimport WeightingStrategy
from ...utils._typedefs cimport intp_t, float64_t
from ...utils._typedefs cimport intp_t, float64_t, uint8_t

import numpy as np
from scipy.sparse import issparse
Expand All @@ -25,7 +25,7 @@ cdef class RadiusNeighborsClassMode{{name_suffix}}(RadiusNeighbors{{name_suffix}
intp_t outlier_label_index
bint outlier_label_exists
bint outliers_exist
unsigned char[::1] outliers
uint8_t[::1] outliers
object outlier_label
float64_t[:, ::1] class_scores
WeightingStrategy weight_type
Expand Down