ENH log1pexp for binomial loss in loss module #21814
Conversation
 cdef inline double log1pexp(double x) nogil:
     if x <= -37:
         return exp(x)
-    elif x <= 18:
+    elif x <= -2:
How was -2 chosen here?
I'll add a shorter note.
The longer story is at the end of section 2 in https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf, which argues for log(2) as the cutoff for the function log(1 - exp(-x)).
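For reference, a minimal NumPy sketch of the note's recipe (the helper name log1mexp and the vectorized form are mine; the log(2) cutoff and the two branches come from the note):

import numpy as np

def log1mexp(x):
    # log(1 - exp(-x)) for x > 0, following the Rmpfr note:
    # up to the cutoff a0 = log(2), -expm1(-x) keeps full precision;
    # beyond it, log1p(-exp(-x)) does.
    x = np.asarray(x, dtype=float)
    return np.where(
        x <= np.log(2.0),
        np.log(-np.expm1(-x)),
        np.log1p(-np.exp(-x)),
    )

print(log1mexp(np.array([1e-3, 0.5, 1.0, 10.0])))

(np.where evaluates both branches, so extremely small x can trigger a harmless warning in the branch that is discarded.)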
I tested this -2 on our function and it gives a difference of 1e-16.
import numpy as np

def diff(x):
    """Return abs diff and rel diff."""
    use_log = np.log(1 + np.exp(x))
    use_log1p = np.log1p(np.exp(x))
    return use_log - use_log1p, use_log / use_log1p - 1

for x in [0, -1, -2, -3]:
    print(f"x={x}: {diff(x)}")
results in
x=0: (0.0, 0.0)
x=-1: (0.0, 0.0)
x=-2: (1.1102230246251565e-16, 8.881784197001252e-16)
x=-3: (-1.0408340855860843e-16, -2.1094237467877974e-15)
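To see why the log1p branch is needed at all, here is a quick check at a very negative x (this example is mine, not from the thread):

import numpy as np

x = -40.0
print(np.log(1 + np.exp(x)))  # 0.0: exp(-40) ~ 4.2e-18 is absorbed by the 1 in 1 + exp(x)
print(np.log1p(np.exp(x)))    # ~4.2e-18, correct
print(np.logaddexp(0.0, x))   # ~4.2e-18, reference value of log(1 + exp(x))

So the plain log is only safe where exp(x) is not negligible next to 1, which is what the -2 cutoff exploits.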
Some timings give

rng = np.random.default_rng(0)
y_true = rng.binomial(1, 0.5, size=100_000).astype(np.float64)
raw = rng.standard_normal(100_000, dtype=np.float64)

%%timeit -r10 -n100
cy_logloss_stable(y_true, raw)
2.22 ms ± 93.8 µs

%%timeit -r10 -n100
cy_logloss_stable_fast(y_true, raw)
1.72 ms ± 81.9 µs

The reason is, of course, that log1pexp_fast calls the plain log instead of log1p on the range -2 < x <= 18:

%%cython -3
# distutils: extra_compile_args = -O3
# cython: cdivision=True
# cython: boundscheck=False
# cython: wraparound=False
import cython
from cython.parallel import prange
import numpy as np
from libc.math cimport exp, log, log1p
cimport numpy as np
np.import_array()
# Numerically stable log(1 + exp(x))
cdef inline double log1pexp(double x) nogil:
    if x <= -37:
        return exp(x)  # log1p(exp(x)) rounds to exp(x) in double precision
    elif x <= 18:
        return log1p(exp(x))
    elif x <= 33.3:
        return x + exp(-x)  # log(1 + exp(x)) = x + log(1 + exp(-x)) ~ x + exp(-x)
    else:
        return x  # exp(-x) is below double precision

# Faster version of numerically stable log(1 + exp(x))
cdef inline double log1pexp_fast(double x) nogil:
    if x <= -37:
        return exp(x)
    elif x <= -2:
        return log1p(exp(x))  # log1p needed: exp(x) << 1
    elif x <= 18:
        return log(1 + exp(x))  # plain log is safe and faster in this range
    elif x <= 33.3:
        return x + exp(-x)
    else:
        return x

cdef inline double c_logloss(double y_true, double raw) nogil:
    return log1pexp(raw) - y_true * raw

cdef inline double c_logloss_fast(double y_true, double raw) nogil:
    return log1pexp_fast(raw) - y_true * raw

def cy_logloss_stable(double[::1] y_true, double[::1] raw):
    cdef:
        int n_samples
        int i
        double[::1] out = np.empty_like(y_true)
    n_samples = y_true.shape[0]
    for i in range(n_samples):
        out[i] = c_logloss(y_true[i], raw[i])
    return np.asarray(out)

def cy_logloss_stable_fast(double[::1] y_true, double[::1] raw):
    cdef:
        int n_samples
        int i
        double[::1] out = np.empty_like(y_true)
    n_samples = y_true.shape[0]
    for i in range(n_samples):
        out[i] = c_logloss_fast(y_true[i], raw[i])
    return np.asarray(out)
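For anyone without a Cython setup, here is a plain NumPy port of log1pexp_fast with the same cutoffs, checked against np.logaddexp (the vectorized port is my sketch; only the cutoffs come from the snippet above):

import numpy as np

def np_log1pexp_fast(x):
    # NumPy port of log1pexp_fast above, same branch cutoffs.
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    m = x <= -37
    out[m] = np.exp(x[m])              # log(1 + exp(x)) ~ exp(x)
    m = (x > -37) & (x <= -2)
    out[m] = np.log1p(np.exp(x[m]))    # log1p needed: exp(x) << 1
    m = (x > -2) & (x <= 18)
    out[m] = np.log(1 + np.exp(x[m]))  # plain log is safe and faster here
    m = (x > 18) & (x <= 33.3)
    out[m] = x[m] + np.exp(-x[m])      # log(1 + exp(x)) = x + log(1 + exp(-x))
    m = x > 33.3
    out[m] = x[m]                      # exp(-x) is below double precision
    return out

x = np.linspace(-50, 50, 100_001)
ref = np.logaddexp(0.0, x)  # log(exp(0) + exp(x)) = log(1 + exp(x))
print(np.max(np.abs(np_log1pexp_fast(x) - ref)))  # ~1e-15, machine precision at these magnitudes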
LGTM
@jjerphan @thomasjpfan This is a tiny change and would help me to move forward.
LGTM
Thanks a lot!
Reference Issues/PRs
Follow-up of #20567
What does this implement/fix? Explain your changes.
This PR improves the helper function log1pexp a bit, which speeds up HalfBinomialLoss.loss(...).
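For context: with p = sigmoid(raw), the per-sample binomial log loss -y*log(p) - (1-y)*log(1-p) simplifies to log1pexp(raw) - y*raw, which is exactly the c_logloss form in the snippet above. A quick NumPy check of that identity (my example, not the loss module's API):

import numpy as np
from scipy.special import expit  # the sigmoid

rng = np.random.default_rng(0)
y = rng.binomial(1, 0.5, size=1000).astype(float)
raw = rng.standard_normal(1000)

# Naive form: -y*log(p) - (1-y)*log(1-p) with p = sigmoid(raw)
p = expit(raw)
naive = -y * np.log(p) - (1 - y) * np.log1p(-p)

# Stable form: log1pexp(raw) - y*raw, with log1pexp via np.logaddexp(0, raw)
stable = np.logaddexp(0.0, raw) - y * raw

print(np.max(np.abs(naive - stable)))  # ~1e-16 for moderate raw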
Any other comments?
Ideally, this PR is merged before #21808 and #20811. It will improve their benchmarks.