
ENH replace loss module Gradient boosting #26278


Merged

Conversation

@lorentzenchr (Member) commented Apr 24, 2023

Reference Issues/PRs

Closes #25964.

What does this implement/fix? Explain your changes.

This replaces the losses from _gb_losses.py with the ones from the common private loss submodule sklearn._loss in a very backward compatible way. Some of the losses in the two implementations differ by a constant factor of 2; those factors are accounted for here.

Any other comments?

  1. The parallelism available in the new loss functions is not enabled yet. This is deferred to later PRs (and contributor capacity).
  2. Should we discuss the backward compatibility of the attributes oob_improvement_ and train_score_ w.r.t. the above-mentioned constant factor of the loss (see the sketch below)? As oob_scores_ and oob_score_ are about to be introduced with 1.3, we could still change them.
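
For illustration, here is a minimal NumPy sketch of the constant factor mentioned in point 2 (an assumption about the conventions, not code from this PR): the binomial deviance of the old _gb_losses.py equals twice the per-sample negative log-likelihood, i.e. twice the "half" binomial loss convention used by sklearn._loss.

import numpy as np

rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=100).astype(float)  # binary targets
raw = rng.normal(size=100)                      # raw (log-odds) predictions
p = 1.0 / (1.0 + np.exp(-raw))                  # predicted probabilities

# "half" binomial loss: mean negative log-likelihood, log(1 + exp(raw)) - y * raw
half_binomial = np.mean(np.logaddexp(0.0, raw) - y * raw)

# old binomial deviance convention: -2 * mean log-likelihood
deviance = -2.0 * np.mean(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

assert np.isclose(deviance, 2.0 * half_binomial)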

@lorentzenchr (Member, Author) commented Apr 24, 2023

A very simple benchmark for binary classification gives:

PR:   768 ms ± 9.76 ms
main: 937 ms ± 128 ms
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
n_samples, n = 10, 10_000
y = np.tile(np.arange(n_samples) % 2, n)
x1 = np.minimum(y, n_samples / 2)
x2 = np.minimum(-y, -n_samples / 2)
X = np.c_[x1, x2]
%timeit GradientBoostingClassifier(n_estimators=100).fit(X, y)

Same for multiclass classification (10 classes) gives:

PR:   9.53 s ± 76.7
main: 28.8 s ± 324 ms
import numpy as np
from sklearn.ensemble import GradientBoostingClassifier
n_samples, n = 10, 10_000
y = np.tile(np.arange(n_samples), n)
x1 = np.minimum(y, n_samples / 2)
x2 = np.minimum(-y, -n_samples / 2)
X = np.c_[x1, x2]
%timeit GradientBoostingClassifier(n_estimators=100).fit(X, y)

@lorentzenchr (Member, Author)

Given the size of this PR ...

Therefore, I provided the different test_XXX_exact_backward_compat tests.

The exact origin for the necessity of this change is unclear. The train_score_ of the GBT inside the pipeline is exactly the same for this branch and 1.2.2.
try:
    return numerator / denominator
except FloatingPointError:
    return 0.0
Member

For codecov, can we add a small test to trigger the divide by zero?

Member Author

I tried to find a case with the GB estimator, but that's hard. So I simply added test_safe_divide.
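
For reference, a sketch of what such a test could look like (hypothetical: the helper shown here is a simplified stand-in, not the merged _safe_divide or test_safe_divide):

import numpy as np

def safe_divide(numerator, denominator):
    # Simplified stand-in for the helper under test: return 0.0 instead of
    # propagating a division error.
    try:
        with np.errstate(divide="raise"):
            return numerator / denominator
    except FloatingPointError:
        return 0.0

def test_safe_divide():
    # array / 0.0 raises FloatingPointError under errstate, so 0.0 is returned
    assert safe_divide(np.array([1.0]), np.float64(0.0)) == 0.0
    # an ordinary division is passed through unchanged
    assert safe_divide(np.float64(3.0), np.float64(2.0)) == 1.5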

Member

Apparently, we have a nice example that checks this one ;)

@thomasjpfan (Member) left a comment

Thank you for the updates! LGTM

@thomasjpfan added the "Waiting for Second Reviewer" label (First reviewer is done, need a second one!) Aug 27, 2023
@lorentzenchr (Member, Author)

Dear future 2nd reviewer: this is much easier to review than it looks; just have a look at the tests.

@OmarManzoor (Contributor) left a comment
Thanks for the PR @lorentzenchr. A few comments.

@OmarManzoor (Contributor) left a comment
Thank you for the updates @lorentzenchr! Looks good now. Just a few minor comments.

@OmarManzoor (Contributor) left a comment
LGTM. Thanks @lorentzenchr

@OmarManzoor merged commit 5dbb8f5 into scikit-learn:main Sep 7, 2023
@ogrisel deleted the gradient_boosting_common_loss branch September 7, 2023 12:48
@lesteve (Member) commented Sep 7, 2023

It looks like one of the examples broke after merging this PR, namely examples/ensemble/plot_gradient_boosting_regularization.py, because of NaNs.

From build log:

Unexpected failing examples:
/home/circleci/project/examples/ensemble/plot_gradient_boosting_regularization.py failed leaving traceback:
Traceback (most recent call last):
  File "/home/circleci/project/examples/ensemble/plot_gradient_boosting_regularization.py", line 77, in <module>
    test_deviance[i] = 2 * log_loss(y_test, y_proba[:, 1])
  File "/home/circleci/project/sklearn/utils/_param_validation.py", line 211, in wrapper
    return func(*args, **kwargs)
  File "/home/circleci/project/sklearn/metrics/_classification.py", line 2831, in log_loss
    y_pred = check_array(
  File "/home/circleci/project/sklearn/utils/validation.py", line 958, in check_array
    _assert_all_finite(
  File "/home/circleci/project/sklearn/utils/validation.py", line 123, in _assert_all_finite
    _assert_all_finite_element_wise(
  File "/home/circleci/project/sklearn/utils/validation.py", line 172, in _assert_all_finite_element_wise
    raise ValueError(msg_err)
ValueError: Input contains NaN.

@glemaitre (Member)

@lesteve was faster than me to report this.
I will try to make a fix if I find out what the issue is :).

@glemaitre (Member) commented Sep 7, 2023

OK, so it is due to _safe_divide, where we expect numerator / denominator to raise an error via np.errstate. However, this does not work when dividing two np.float64 scalars:

In [5]: with np.errstate(divide="raise"):
   ...:     np.float64(0.0) / np.float64(0.0)
   ...: 
<ipython-input-5-941bf0f2d1fa>:2: RuntimeWarning: invalid value encountered in scalar divide
  np.float64(0.0) / np.float64(0.0)

In [6]: with np.errstate(divide="raise"):
   ...:     np.array([1.0, 2.0]) / np.float64(0.0)
   ...: 
---------------------------------------------------------------------------
FloatingPointError                        Traceback (most recent call last)
Cell In[6], line 2
      1 with np.errstate(divide="raise"):
----> 2     np.array([1.0, 2.0]) / np.float64(0.0)

FloatingPointError: divide by zero encountered in divide

Looking at the code, it seems we expect the second case. If we are sure to always have scalars, we could convert them to Python scalars, which would always raise an error.
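
One possible way to also cover the scalar np.float64(0.0) / np.float64(0.0) case (a sketch of an option, not necessarily the fix that was eventually merged, and assuming the installed NumPy applies errstate to scalar operations as the warning above suggests) is to raise on "invalid" operations as well, since 0/0 produces nan under the "invalid" category rather than "divide":

import numpy as np

def safe_divide(numerator, denominator):
    # Hypothetical variant: "divide" covers x / 0 for arrays, while "invalid"
    # covers the scalar 0.0 / 0.0 case, which only warns under divide="raise".
    try:
        with np.errstate(divide="raise", invalid="raise"):
            return numerator / denominator
    except FloatingPointError:
        return 0.0

print(safe_divide(np.float64(0.0), np.float64(0.0)))  # 0.0 instead of nan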

@glemaitre (Member)

Actually, the only case that goes sideways is when both the numerator and the denominator go to 0.0. I assume this can happen when the previous iteration returned 0.0 because of the FloatingPointError.

REDVM pushed a commit to REDVM/scikit-learn that referenced this pull request Nov 16, 2023
iamDecode added a commit to iamDecode/sklearn-pmml-model that referenced this pull request Apr 14, 2024
valeriy42 added a commit to elastic/eland that referenced this pull request Nov 11, 2024
Introduce a warning indicating that exporting data frame analytics models as ESGradientBoostingModel subclasses is deprecated and will be removed in version 9.0.0.

The implementation of ESGradientBoostingModel relies on importing undocumented private classes that were changed in scikit-learn 1.4 by scikit-learn/scikit-learn#26278. This dependency makes the code difficult to maintain, while the functionality is not widely used. Therefore, we will deprecate this functionality in 8.16 and remove it completely in 9.0.0.

---------

Co-authored-by: Quentin Pradet <quentin.pradet@elastic.co>
Labels
module:ensemble, Waiting for Second Reviewer (First reviewer is done, need a second one!)

Successfully merging this pull request may close these issues:
Use common loss module in gradient boosting

5 participants