
MAINT better fused type names in loss module #27330

Merged (1 commit) on Sep 11, 2023
12 changes: 8 additions & 4 deletions sklearn/_loss/_loss.pxd
@@ -1,13 +1,17 @@
# cython: language_level=3

# Fused types for y_true, y_pred, raw_prediction
ctypedef fused Y_DTYPE_C:
# Fused types for input like y_true, raw_prediction, sample_weights.
ctypedef fused floating_in:
double
float


# Fused types for gradient and hessian
ctypedef fused G_DTYPE_C:
# Fused types for output like gradient and hessian
# We use different fused types for input (floating_in) and output (floating_out), such
# that input and output can have different dtypes in the same function call. A single
# fused type can only take on one single value (type) for all arguments in one function
# call.
ctypedef fused floating_out:
double
float

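Aside (an illustrative sketch, not part of this diff): the reason for two separate fused types is that a single Cython fused type must resolve to one concrete type for all arguments of a given call. With distinct floating_in and floating_out, a caller can pass float32 inputs while writing into a float64 output. A minimal hypothetical function using the new names:

    cdef void cy_scale(
        const floating_in[::1] x,   # IN, may resolve to float or double
        floating_out[::1] out,      # OUT, resolved independently of x
    ) noexcept nogil:
        # Because x and out use different fused types, x=float32 together with
        # out=float64 is a valid specialization; with one shared fused type for
        # both arguments it would not be.
        cdef int i
        for i in range(x.shape[0]):
            out[i] = 2.0 * x[i]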
168 changes: 85 additions & 83 deletions sklearn/_loss/_loss.pyx.tp
@@ -268,8 +268,8 @@ cdef inline double log1pexp(double x) noexcept nogil:

cdef inline void sum_exp_minus_max(
const int i,
const Y_DTYPE_C[:, :] raw_prediction, # IN
Y_DTYPE_C *p # OUT
const floating_in[:, :] raw_prediction, # IN
floating_in *p # OUT
) noexcept nogil:
# Thread local buffers are used to store the results of this function via p.
# The results are stored as follows:
@@ -744,7 +744,7 @@ cdef inline double_pair cgrad_hess_half_binomial(
double raw_prediction
) noexcept nogil:
# with y_pred = expit(raw)
# hessian = y_pred * (1 - y_pred) = exp(raw) / (1 + exp(raw))**2
# hessian = y_pred * (1 - y_pred) = exp( raw) / (1 + exp( raw))**2
# = exp(-raw) / (1 + exp(-raw))**2
cdef double_pair gh
gh.val2 = exp(-raw_prediction) # used as temporary
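A quick check of the identity in that comment (not part of the diff): with y_pred = expit(raw) = 1 / (1 + exp(-raw)), one has 1 - y_pred = exp(-raw) / (1 + exp(-raw)), hence

    y_pred * (1 - y_pred) = exp(-raw) / (1 + exp(-raw))**2 = exp(raw) / (1 + exp(raw))**2,

which is exactly the quantity the snippet above builds from the temporary gh.val2 = exp(-raw_prediction).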
@@ -835,7 +835,9 @@ cdef class CyLossFunction:
"""
pass

cdef double_pair cy_grad_hess(self, double y_true, double raw_prediction) noexcept nogil:
cdef double_pair cy_grad_hess(
self, double y_true, double raw_prediction
) noexcept nogil:
"""Compute gradient and hessian.

Gradient and hessian of loss w.r.t. raw_prediction for a single sample.
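For intuition (a hedged sketch, not from the diff; assuming the first field of double_pair is named val1, matching the val2 field used above): a concrete subclass fills both fields, e.g. for the half squared error loss 0.5 * (raw_prediction - y_true)**2 the override could look like:

    cdef double_pair cy_grad_hess(
        self, double y_true, double raw_prediction
    ) noexcept nogil:
        cdef double_pair gh
        gh.val1 = raw_prediction - y_true   # gradient of the loss w.r.t. raw_prediction
        gh.val2 = 1.0                       # hessian is constant for squared error
        return gh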
@@ -862,10 +864,10 @@ cdef class CyLossFunction:

def loss(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] loss_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
int n_threads=1
):
"""Compute the pointwise loss value for each input.
@@ -892,10 +894,10 @@ cdef class CyLossFunction:

def gradient(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] gradient_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] gradient_out, # OUT
int n_threads=1
):
"""Compute gradient of loss w.r.t raw_prediction for each input.
@@ -922,11 +924,11 @@ cdef class CyLossFunction:

def loss_gradient(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] loss_out, # OUT
G_DTYPE_C[::1] gradient_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
floating_out[::1] gradient_out, # OUT
int n_threads=1
):
"""Compute loss and gradient of loss w.r.t raw_prediction.
@@ -960,11 +962,11 @@ cdef class CyLossFunction:

def gradient_hessian(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] gradient_out, # OUT
G_DTYPE_C[::1] hessian_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] gradient_out, # OUT
floating_out[::1] hessian_out, # OUT
int n_threads=1
):
"""Compute gradient and hessian of loss w.r.t raw_prediction.
@@ -1022,10 +1024,10 @@ cdef class {{name}}(CyLossFunction):

def loss(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] loss_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
int n_threads=1
):
cdef:
@@ -1048,11 +1050,11 @@ cdef class {{name}}(CyLossFunction):
{{if closs_grad is not None}}
def loss_gradient(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] loss_out, # OUT
G_DTYPE_C[::1] gradient_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
floating_out[::1] gradient_out, # OUT
int n_threads=1
):
cdef:
@@ -1080,10 +1082,10 @@ cdef class {{name}}(CyLossFunction):

def gradient(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] gradient_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] gradient_out, # OUT
int n_threads=1
):
cdef:
@@ -1105,11 +1107,11 @@ cdef class {{name}}(CyLossFunction):

def gradient_hessian(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[::1] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] gradient_out, # OUT
G_DTYPE_C[::1] hessian_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[::1] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] gradient_out, # OUT
floating_out[::1] hessian_out, # OUT
int n_threads=1
):
cdef:
@@ -1158,18 +1160,18 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
# opposite are welcome.
def loss(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[:, :] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] loss_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[:, :] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
int n_threads=1
):
cdef:
int i, k
int n_samples = y_true.shape[0]
int n_classes = raw_prediction.shape[1]
Y_DTYPE_C max_value, sum_exps
Y_DTYPE_C* p # temporary buffer
floating_in max_value, sum_exps
floating_in* p # temporary buffer

# We assume n_samples > n_classes. In this case having the inner loop
# over n_classes is a good default.
@@ -1181,7 +1183,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1197,7 +1199,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
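The hunks above and below all touch the same parallel pattern; as an isolated sketch (illustrative excerpt only, with p declared as floating_in* in the function's cdef block as shown in the diff): each OpenMP thread allocates its own scratch buffer of the renamed floating_in type, uses it inside prange, and frees it afterwards:

    with nogil, parallel(num_threads=n_threads):
        # one private scratch buffer per thread
        p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))
        for i in prange(n_samples, schedule='static'):
            sum_exp_minus_max(i, raw_prediction, p)
            # ... read the per-thread results from p and write the output arrays ...
        free(p)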
@@ -1218,26 +1220,26 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):

def loss_gradient(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[:, :] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[::1] loss_out, # OUT
G_DTYPE_C[:, :] gradient_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[:, :] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[::1] loss_out, # OUT
floating_out[:, :] gradient_out, # OUT
int n_threads=1
):
cdef:
int i, k
int n_samples = y_true.shape[0]
int n_classes = raw_prediction.shape[1]
Y_DTYPE_C max_value, sum_exps
Y_DTYPE_C* p # temporary buffer
floating_in max_value, sum_exps
floating_in* p # temporary buffer

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1256,7 +1258,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1280,25 +1282,25 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):

def gradient(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[:, :] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[:, :] gradient_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[:, :] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[:, :] gradient_out, # OUT
int n_threads=1
):
cdef:
int i, k
int n_samples = y_true.shape[0]
int n_classes = raw_prediction.shape[1]
Y_DTYPE_C sum_exps
Y_DTYPE_C* p # temporary buffer
floating_in sum_exps
floating_in* p # temporary buffer

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1312,7 +1314,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1329,26 +1331,26 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):

def gradient_hessian(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[:, :] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[:, :] gradient_out, # OUT
G_DTYPE_C[:, :] hessian_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[:, :] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[:, :] gradient_out, # OUT
floating_out[:, :] hessian_out, # OUT
int n_threads=1
):
cdef:
int i, k
int n_samples = y_true.shape[0]
int n_classes = raw_prediction.shape[1]
Y_DTYPE_C sum_exps
Y_DTYPE_C* p # temporary buffer
floating_in sum_exps
floating_in* p # temporary buffer

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1364,7 +1366,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1387,26 +1389,26 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
# diagonal (in the classes) approximation as implemented above.
def gradient_proba(
self,
const Y_DTYPE_C[::1] y_true, # IN
const Y_DTYPE_C[:, :] raw_prediction, # IN
const Y_DTYPE_C[::1] sample_weight, # IN
G_DTYPE_C[:, :] gradient_out, # OUT
G_DTYPE_C[:, :] proba_out, # OUT
const floating_in[::1] y_true, # IN
const floating_in[:, :] raw_prediction, # IN
const floating_in[::1] sample_weight, # IN
floating_out[:, :] gradient_out, # OUT
floating_out[:, :] proba_out, # OUT
int n_threads=1
):
cdef:
int i, k
int n_samples = y_true.shape[0]
int n_classes = raw_prediction.shape[1]
Y_DTYPE_C sum_exps
Y_DTYPE_C* p # temporary buffer
floating_in sum_exps
floating_in* p # temporary buffer

if sample_weight is None:
# inner loop over n_classes
with nogil, parallel(num_threads=n_threads):
# Define private buffer variables as each thread might use its
# own.
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
@@ -1420,7 +1422,7 @@ cdef class CyHalfMultinomialLoss(CyLossFunction):
free(p)
else:
with nogil, parallel(num_threads=n_threads):
p = <Y_DTYPE_C *> malloc(sizeof(Y_DTYPE_C) * (n_classes + 2))
p = <floating_in *> malloc(sizeof(floating_in) * (n_classes + 2))

for i in prange(n_samples, schedule='static'):
sum_exp_minus_max(i, raw_prediction, p)
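Finally, a usage-level sketch of what the two fused types allow (hypothetical: the class name and the mixed-precision call illustrate the private API and are not taken from this PR):

    import numpy as np
    from sklearn._loss._loss import CyHalfSquaredError  # generated from the template; name assumed

    loss = CyHalfSquaredError()
    y_true = np.array([0.0, 1.0, 2.0], dtype=np.float32)   # floating_in resolves to float
    raw = np.array([0.1, 0.8, 1.5], dtype=np.float32)      # floating_in resolves to float
    out = np.empty(3, dtype=np.float64)                    # floating_out resolves to double
    loss.loss(y_true=y_true, raw_prediction=raw, sample_weight=None,
              loss_out=out, n_threads=1)

Before this PR the same mixed-dtype calls were possible with Y_DTYPE_C and G_DTYPE_C; the rename to floating_in and floating_out only makes the input/output roles explicit.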