
A common private module for differentiable loss functions used as objective functions in estimators #15123


Open
rth opened this issue Oct 3, 2019 · 34 comments

@rth
Member

rth commented Oct 3, 2019

Currently, each model that needs losses defines its own; it might be useful to put them all in one place to see if anything could be re-used in the future.

In particular, losses are defined in:

  • HistGradientBoosting: sklearn/ensemble/_hist_gradient_boosting/loss.py
  • GradientBoosting: sklearn/ensemble/_gb_losses.py
  • SGD: sklearn/linear_model/sgd_fast.pyx
  • GLM: Minimal Generalized linear models implementation (L2 + lbfgs) #14300 would add sklearn/linear_model/_glm/distributions.py
  • MLP: sklearn/neural_network/_base.py:LOSS_FUNCTIONS
  • loss implementations in sklearn.metrics
  • and, somewhat related, RandomForest: sklearn/tree/_criterion.pyx

This issue proposes to progressively move most of them (except for the RF criteria) into a single private folder sklearn/_loss to see what could be made less redundant there.

In particular, for the GLM PR #14300 this would allow breaking the circular dependency between the sklearn.linear_model and sklearn.metrics modules, see #14300 (comment)

@NicolasHug I don't know if you already had some plans about unifying new and legacy gradient boosting losses and if this would go in the right direction. cc @ogrisel

Weakly related to #5044 #3481 #5368

@NicolasHug
Member

I'd be happy to unify everything... But that looks like a hard problem.

The first thing that we can reasonably unify is the __call__ methods. But even here, implementations differ on small details. For example, for the old GBDTs the LS loss is a proper LS loss, while for the histogram-based GBDTs it's a half LS (which makes the rest of the code simpler).

If we want to unify more things, the problem gets even harder. E.g. gradient computation: with respect to what, predictions or parameters? predicted classes or predicted decision function? etc... And of course, input shape issues...

@rth
Member Author

rth commented Oct 3, 2019

I agree that it's not a trivial problem, maybe even more so for GBDT, and I'm not sure that everything can indeed be unified. But, say, unifying MLP losses with SGD losses should be relatively easy, I think; we could start there.

As long as everything is private, moving incrementally could be doable.

@NicolasHug
Member

NicolasHug commented Oct 3, 2019

unifying MLP losses with SGD losses should be relatively easy

Briefly looking at it, MLP defines a half LS loss for numpy arrays in pure Python, while the SGD losses are point-wise Cythonized methods (also a half LS).

EDIT: they are also nogil

@rth
Member Author

rth commented Oct 3, 2019

Indeed. Well we could have one cdef method for scalar nogil calls and one cpdef for vectorized ones in the same class. That would also help make sure they are consistent.

@rth rth closed this as completed Oct 3, 2019
@rth rth reopened this Oct 3, 2019
@NicolasHug
Member

I agree this is a bit annoying to have all those implementations in different places, but OTOH writing losses is in general pretty easy and bugs are easily caught.

So I wouldn't make it a high priority right now maybe

@ogrisel
Member

ogrisel commented Oct 15, 2019

So I wouldn't make it a high priority right now maybe

Duplication was bearable so far, but personally I would like to extend Tweedie regression to HistGBRT (and probably classical GBRT as well) and MLP, and for this loss family we probably need to factor out common private functions as they are significantly more complex than (half) least squares.
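For context, a minimal NumPy sketch (my own, not from this thread) of the unit Tweedie deviance for power p outside {0, 1, 2} already illustrates the extra complexity compared to the half squared error 0.5 * (y_true - y_pred) ** 2:

import numpy as np

def tweedie_deviance(y_true, y_pred, power):
    # unit Tweedie deviance for power p not in {0, 1, 2};
    # p = 1 (Poisson) and p = 2 (Gamma) need separate, log-based formulas
    p = power
    return 2 * (
        np.power(y_true, 2 - p) / ((1 - p) * (2 - p))
        - y_true * np.power(y_pred, 1 - p) / (1 - p)
        + np.power(y_pred, 2 - p) / (2 - p)
    )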

@ogrisel ogrisel changed the title A common module for losses A common private module for differentiable loss functions used as objective functions in estimators Oct 15, 2019
@lorentzenchr
Member

Should the link function be tied together with/incorporated into the loss class or should they be independent?

Example binary log-loss/cross entropy:
Let y_pred be the predicted probability for outcome y_true=+1. Then one usually has y_pred = expit(z) with link function logit. Here, z is the internally estimated or predicted object. For a linear model, z=X@coef. Should the loss function then depend

  1. on z and automatically apply the inverse link function expit, i.e. loss(y_true, z) = -(y_true==1) * log(expit(z)) - (y_true != 1) * log(1-expit(z)); or
  2. on y_pred, i.e. loss(y_true, y_pred) = -(y_true==1) * log(y_pred) - (y_true != 1) * log(1-y_pred)?

With version 1, simplifications and more robust implementations are sometimes possible, for example

loss = np.logaddexp(0, raw_predictions) - y_true * raw_predictions

Note that raw_prediction = z.

On top of that, a gradient function in version 1 would automatically involve the derivative of the link function. For optimization this variant might be advantageous; the opposite holds for validation metrics.
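To make the two variants concrete, here is a minimal NumPy sketch (the function names and the y_true in {0, 1} encoding are my choices); it also shows why version 1 admits the numerically friendlier logaddexp form:

import numpy as np
from scipy.special import expit

def loss_v2(y_true, y_pred):
    # version 2: operates on probabilities y_pred = expit(z)
    return -(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

def loss_v1(y_true, z):
    # version 1: operates on the raw prediction z, inverse link applied implicitly;
    # log(1 + exp(z)) - y*z == -y*log(expit(z)) - (1-y)*log(1 - expit(z))
    return np.logaddexp(0, z) - y_true * z

z = np.array([-500.0, -2.0, 0.0, 3.0, 500.0])
y = np.array([0.0, 1.0, 1.0, 0.0, 1.0])
print(loss_v1(y, z))         # finite everywhere
print(loss_v2(y, expit(z)))  # nan/inf at the extreme raw predictions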

@jnothman
Member

jnothman commented Mar 16, 2020 via email

@ogrisel
Member

ogrisel commented Mar 16, 2020 via email

@lorentzenchr
Member

lorentzenchr commented Apr 2, 2020

What do you think about the following design?
The goal is that at least the losses of histogram gradient boosting and (generalized) linear models can be replaced by a common implementation.

1. Link functions

Close to what we already have in linear_model/_glm/link.py.

from abc import ABC, abstractmethod


class BaseLink(ABC):
    """Abstract base class for differentiable, invertible link functions.

    Convention:
        - link function g: raw_prediction = g(y_pred)
        - inverse link h: y_pred = h(raw_prediction)
    """

    @abstractmethod
    def link(self, y_pred, out=None):
        """Compute g(y_pred)."""

    @abstractmethod
    def derivative(self, y_pred, out=None):
        """Compute g'(y_pred)."""

    @abstractmethod
    def inverse(self, raw_prediction, out=None):
        """Compute h(raw_prediction) = g^-1(raw_prediction)."""

    @abstractmethod
    def inverse_derivative(self, raw_prediction, out=None):
        """Compute h'(raw_prediction)."""

As an example, y_pred is a probability for classification. For all linear models, raw_prediction = X @ coef.
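As a concrete illustration (my own sketch building on the BaseLink class above, not a definitive API), a logit link for binary classification could look like:

import numpy as np
from scipy.special import expit, logit


class LogitLink(BaseLink):
    """Logit link: raw_prediction = logit(y_pred), y_pred = expit(raw_prediction)."""

    def link(self, y_pred, out=None):
        return logit(y_pred, out=out)

    def derivative(self, y_pred, out=None):
        # d/dp logit(p) = 1 / (p * (1 - p))
        return np.divide(1.0, y_pred * (1.0 - y_pred), out=out)

    def inverse(self, raw_prediction, out=None):
        return expit(raw_prediction, out=out)

    def inverse_derivative(self, raw_prediction, out=None):
        # d/dz expit(z) = expit(z) * (1 - expit(z))
        p = expit(raw_prediction)
        return np.multiply(p, 1.0 - p, out=out)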

2. Base Loss class

class BaseLoss(BaseLink):
    def __init__(self, hessians_are_constant=False):
        self.hessians_are_constant = hessians_are_constant

    def __call__(self, y_true, raw_prediction, sample_weight=None):
        """Return the weighted average loss."""

    @abstractmethod
    def pointwise_loss(self, y_true, raw_prediction, sample_weight=None,
                       out=None):
        """Return loss value for each input.""""

    @abstractmethod
    def fit_intercept_only(self, y_true, sample_weight=None):
        """Return raw_prediction of an intercept-only model."""

    def gradient(self, y_true, raw_prediction, sample_weight=None, out=None):
        """Calculate gradient of loss w.r.t. raw_prediction."""
        # the inverse link is available via the BaseLink mix-in (cf. the class hierarchy in 3.)
        y_pred = self.inverse(raw_prediction)
        out = self.gradient_ypred(y_true=y_true, y_pred=y_pred,
                                  sample_weight=sample_weight, out=out)
        return out

    @abstractmethod
    def gradient_ypred(self, y_true, y_pred, sample_weight=None, out=None):
        """Calculate gradient of loss w.r.t. raw_prediction."""

    def gradient_hessian(self, y_true, raw_prediction, sample_weight=None,
                         gradient=None, hessian=None):
        """Calculate gradient and hessian of loss w.r.t. raw_prediction."""
        ...

    @abstractmethod
    def gradient_hessian_ypred(self, y_true, y_pred, sample_weight=None,
                               gradient=None, hessian=None):
        """Calculate gradient and hessian of loss w.r.t. raw_prediction."""

3. Specific loss classes

class BinaryCrossEntropy(LogitLink, BaseLoss):
...

Notes

What I like here is that the hessian (for single-target estimation) would be a 1d array representing the diagonal of the 2d matrix W. For linear models, the full hessian with respect to the coefficients is then X.T @ W @ X. Having access to W instead of the full hessian can in fact be very advantageous for solvers. AFAIK, it's also what hgb needs.
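A small NumPy sketch (my own, simplified and ignoring sample weights) of how a linear-model solver could use that 1d hessian diagonal without ever materializing the n x n matrix W:

import numpy as np

def newton_step(X, g, w):
    # g: shape (n_samples,), gradient of the loss w.r.t. raw_prediction = X @ coef
    # w: shape (n_samples,), diagonal of W (hessian of the loss w.r.t. raw_prediction)
    grad_coef = X.T @ g                 # gradient w.r.t. coef
    hess_coef = X.T @ (w[:, None] * X)  # X.T @ W @ X, using only the diagonal of W
    return -np.linalg.solve(hess_coef, grad_coef)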

@lorentzenchr
Member

The more I think about the design of such a loss function module, the more I think one should distinguish between loss functions used for fitting/estimation and loss functions used for validation/comparison (module sklearn.metrics). The reason is that fitting requires the most efficient code, while model validation/scoring requires the most interpretable code.

Example: The Poisson deviance has its minimum at 0. This is nice for interpretation and allows to define an analog of R^2. But for fitting it is possible to omit some "constant" terms, with the downside of not knowing the optimal point anymore.
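A NumPy sketch of the difference (function names are mine): the dropped terms depend only on y_true, so both versions have the same gradient in y_pred, but the reduced one no longer attains 0 at its minimum:

import numpy as np
from scipy.special import xlogy  # xlogy(0, y) == 0, handles y_true == 0

def poisson_deviance(y_true, y_pred):
    # full per-sample deviance, equal to 0 at y_pred == y_true
    return 2 * (xlogy(y_true, y_true / y_pred) - y_true + y_pred)

def poisson_fitting_loss(y_true, y_pred):
    # drops the terms that are constant in y_pred; only useful for optimization
    return 2 * (y_pred - xlogy(y_true, y_pred))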

@rth
Member Author

rth commented May 10, 2020

Example: The Poisson deviance has its minimum at 0. This is nice for interpretation and allows to define an analog of R^2. But for fitting it is possible to omit some "constant" terms, with the downside of not knowing the optimal point anymore.

Except that we do print the loss occasionally, e.g. to monitor convergence, and if it doesn't match the documented loss it can lead to confusion.

@lorentzenchr
Member

@rth Do you have a suggestion how to tackle this? Would it be enough to document the discrepancy?
Another example is the squared error: hgbt uses the half squared error, which gives nicer expressions for gradient and hessian; see the code comment "This actually computes the half least squares loss to simplify [...]".
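For reference, a tiny sketch of why the half squared error is convenient: the factor 1/2 cancels the 2 coming from differentiation, so the gradient is simply raw_prediction - y_true and the hessian is exactly 1:

import numpy as np

def half_squared_error(y_true, raw_prediction):
    return 0.5 * (y_true - raw_prediction) ** 2

def half_squared_error_gradient(y_true, raw_prediction):
    # d/draw 0.5 * (y - raw)**2 = raw - y (no factor 2)
    return raw_prediction - y_true

def half_squared_error_hessian(y_true, raw_prediction):
    # second derivative w.r.t. raw_prediction is constant 1
    return np.ones_like(raw_prediction)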

@lorentzenchr
Member

The main challenge, I think, is to define a common API. Some estimators expect a single number as output, others arrays. Some would like to use parallelized versions, others their own parallelism, ...

@lorentzenchr
Member

Among my main motivations are:

  • easier generalization/adaptation of losses for other estimators
  • implementation of efficient solvers for GLMs without code bloat

@rth
Member Author

rth commented May 10, 2020

Yes, I think as long as the used loss is documented it should be fine. It's a challenging project indeed :)

@lorentzenchr
Member

lorentzenchr commented May 10, 2020

So let's play a little IT architect. Current use of loss functions (for single targets).

1. HistGradientBoosting

files in sklearn.ensemble._hist_gradient_boosting: loss.py, _loss.pyx

  • def __call__(self, y_true, raw_predictions, sample_weight) -> float:
    Computes the weighted average loss.
  • def pointwise_loss(self, y_true, raw_predictions) -> ndarray:
    Computes the per sample losses.
  • def update_gradients_and_hessians(self, gradients, hessians, y_true, raw_predictions, sample_weight) -> None:
    Computes gradient and hessian and writes them into float32 ndarrays gradient and hessian. Implemented in cython with loop over samples for i in prange(n_samples, schedule='static', nogil=True):
  • get_baseline_prediction(self, y_train, sample_weight, prediction_dim) -> float:
    Compute initial estimates for "raw_prediction", i.e. an intercept-only model.
  • Knows the notion of a link function.

2. GradientBoosting

file sklearn/ensemble/_gb_losses.py

  • def __call__(self, y, raw_predictions, sample_weight=None) -> float:
    Computes the weighted average loss.
  • def negative_gradient(self, y, raw_predictions, **kargs) -> ndarray:
  • def init_estimator(self) -> estimator:
    Sets a default estimator for initialization of "raw_prediction" via get_init_raw_predictions, which in turn just calls estimator.predict(X).
  • No hessians.
  • Some more functions specific to gradient boosted trees.

3. SGD

file sklearn/linear_model/_sgd_fast.pyx

  • Everything in cython in an extension type (cdef class).
  • cdef double loss(self, double p, double y) nogil:
    Compute loss for one single sample.
  • cdef double dloss(self, double p, double y) nogil:
    Compute gradient for one single sample.
    Note: Python version py_dloss calls dloss. Only used for testing.
  • def dloss(self, double p, double y):
    return self._dloss(p, y)
  • No hessians.

4. GLM

To make it short, gradients, hessians and link function as for HistGradientBoosting are sufficient.

5. MLP

file sklearn/neural_network/_base.py

  • def squared_loss(y_true, y_pred) -> float:
    Computes the average loss. There are several such functions, one for each loss: squared_loss, log_loss and binary_log_loss.
  • Gradients are (by chance 😏) hard coded in _backprop as y_pred - y_true (see the sketch at the end of this post).
  • No hessians.
  • Knows the notion of a link function (here "activation" function) and its derivatives, e.g. logistic(X) and inplace_logistic_derivative(Z, delta).

6. Trees

  • Everything in cython.
  • cdef double proxy_impurity_improvement(self) nogil:
    Computes proxy of loss improvement of possible splits.
  • cdef void children_impurity(self, double* impurity_left, double* impurity_right) nogil:
    Computes the loss improvement of a split.
  • No gradients nor hessians.

This post might be updated in the future.
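Regarding the MLP gradient "by chance" being y_pred - y_true: for the log-loss combined with the logistic activation (and similarly for half squared error with the identity activation), the chain rule gives exactly expit(z) - y_true as gradient with respect to the raw prediction z. A quick numerical check (my own sketch, not from this thread):

import numpy as np
from scipy.special import expit

rng = np.random.default_rng(0)
z = rng.normal(size=5)                         # raw predictions
y = rng.integers(0, 2, size=5).astype(float)   # labels in {0, 1}

def logloss(y_true, z):
    return np.logaddexp(0, z) - y_true * z

analytic = expit(z) - y                        # i.e. y_pred - y_true

eps = 1e-6
numeric = (logloss(y, z + eps) - logloss(y, z - eps)) / (2 * eps)

print(np.allclose(analytic, numeric))          # True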

@lorentzenchr
Member

My plan would be to have one working design for HistGradientBoosting and GLM.
This might already work for MLP as well. Then, changing a sign might be enough to include GradientBoosting. Adding single-sample gradients in cython makes SGD doable. I see a red flag for Trees, as the splitting criterion can often be optimized by some algebra. As mentioned earlier, I would not aim for losses in the metrics module, but one could add some tests that the losses are consistent (have the same minimum, etc.).

@lorentzenchr
Member

lorentzenchr commented May 28, 2020

Currently, I'm blocked by cython/cython#3607. My idea was for each loss class to implement only 3 cython functions (loss, grad, hess) for scalar values (single samples) and then use virtual inheritance from the parent class to deliver these functions working on arrays. Something like:

Edit: Changed cdef G_DTYPE cgradient(self, Y_DTYPE ...) into cdef double cgradient(self, double ...) as it does work and is not blocked by cython/cython#3607.

cimport cython
from cython.parallel import prange

# fused type for target variables
ctypedef fused Y_DTYPE:
    cython.float
    cython.double

# fused type for gradients and hessians
ctypedef fused G_DTYPE:
    cython.float
    cython.double

cdef class cLossFunction:
    ...
    cdef double cgradient(self, double y_true, double raw_prediction) nogil:
        return 0
        
    def gradient(self,
                       Y_DTYPE[:] y_true,          # IN
                       Y_DTYPE[:] raw_prediction,  # IN
                       Y_DTYPE[:] sample_weight,   # IN
                       G_DTYPE[:] out,             # OUT
                       ):
        cdef:
            int n_samples
            int i

        n_samples = raw_prediction.shape[0]
        for i in prange(n_samples, schedule='static', nogil=True):
            out[i] = sample_weight[i] * self.cgradient(y_true[i], raw_prediction[i])


cdef class cHalfSquaredError(cLossFunction):
    """Half Squared Error with identity link."""
    ...
    cdef double cgradient(self, double y_true, double raw_prediction) nogil:
        return raw_prediction - y_true

The point is that HistGradientBoosting is longing for extension types, as it uses different dtypes for the target y and for the gradients and hessians, for instance.

@lorentzenchr
Member

lorentzenchr commented Oct 19, 2020

As long as cython/cython#3607 is unresolved, I would be interested in opinions on the following 3 workarounds:

  1. For the time being, just implement double.
    Edit: only double for cdef functions.
  2. Use C++ templates in order to have double and float versions (as Y_DTYPE and G_DTYPE).
  3. Implement everything twice with G_DTYPE as float and double, Y_DTYPE as double.

@NicolasHug
Member

1 means we're not backward compatible, right?

For 2 and 3, unfortunately I don't really see how these options would effectively help reduce the complexity or redundancy of the losses code, which I believe was the original motivation here.

@lorentzenchr
Member

lorentzenchr commented Oct 26, 2020

News: cython/cython#3607 doesn't seem to be a blocker anymore, as long as cdef functions like cdef cgradient do not use fused types. The biggest issue with this design, however, is that the virtual inheritance adds a performance penalty (some percentage, to be verified).

Design reasoning:

  • cdef loss and gradient functions for SGD, operating on a single sample point.
    • Seems to be enough to implement only double.
  • def loss and gradient functions that operate on numpy arrays (or memoryviews). Ideally, they call the cdef versions in a loop.
    a) Ideally, these would only be defined once, not in every loss class again => much less code!
    b) Testing the def functions automatically tests the cdef versions => much less tests!
    c) Provide option for multithreading, i.e. use prange to iterate over samples.
    d) Provide float and double simultaneously.

This sounds a bit like the function dispatch mechanism for generating numpy ufuncs, except for the prange option.

@lorentzenchr
Member

lorentzenchr commented Nov 9, 2020

Updated 24.11.2020
I did some benchmarking for the log-loss calculated on ndarrays of shape (N,), here. It compares the same function implemented as:

  1. numpy np_logloss
  2. cython loop over single sample C function p_logloss
  3. cython loop via generic function with pointer to single sample C function p_logloss_generic
  4. cython loop in class method over static C member function (same as 2 but all within a class) Logloss.loss
  5. cython loop in class method via generic function with pointer to static C member function (analogue of 3 but all within a class) Logloss.loss
  6. Same as 4, but the loop is virtually inherited from a base class Logloss_virtual.loss

Time ratio of log-loss (compared to numpy np_logloss, smaller is better): [plot omitted]

The pure numpy version is much faster. For N >= 1e4, p_logloss, p_logloss_generic and Logloss.loss perform similarly well. Virtual functions and, for N >= 1e4, also the generic loop with a function pointer to a static member function have some significant overhead.

Time ratio of gradient of log-loss: [plot omitted]

Absolute timings with error bars are in the details below.

Conclusion

Virtual inheritance has some significant overhead. For larger sample sizes, a generic loop with function pointers to static member functions has a similar performance overhead. On the other hand, the generic loop over normal (non-member) functions does not suffer this penalty.

[absolute timing plots omitted]

@lorentzenchr
Member

Another question of mine: how important is parallelism? Currently, prange is used only in the histogram-based gradient boosting trees. @NicolasHug, @thomasjpfan or @ogrisel maybe?

@lorentzenchr
Member

lorentzenchr commented Nov 24, 2020

Assumptions:
a) Loss, Gradient and Hessian have C implementations working on single points (1 sample).
b) L, G, H have a ndarray implementation as well.
c) Performance is important.
d) Parallelism by prange is important.
e) There are at least squared error, absolute error, poisson deviance, log-loss. Possibly many more.

With the above assumptions and the benchmarks, I propose the following sketched design:

import numpy as np
from libc.math cimport exp, log1p

ctypedef double (*fpointer)(double, double) nogil

cdef void generic_loop(fpointer f, double[::1] y_true, double[::1] raw, double[::1] out):
    cdef int i, n_samples = y_true.shape[0]
    with nogil:
        for i in range(n_samples):  # could be prange
            out[i] = f(y_true[i], raw[i])

cdef double c_logloss(double y_true, double raw) nogil:
    """Log-loss for a single point, cf. np.logaddexp(0, raw) - y_true * raw above."""
    return log1p(exp(raw)) - y_true * raw

cdef class Logloss():

    cdef double loss_point(self, double y_true, double raw) nogil:
        return c_logloss(y_true, raw)

    def loss(self, double[::1] y_true, double[::1] raw):
        cdef double[::1] out = np.empty_like(y_true)
        generic_loop(c_logloss, y_true, raw, out)
        return np.asarray(out)
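Assuming the sketch above were compiled into an extension module (the module name below is hypothetical, just for illustration), usage from Python would then be:

import numpy as np
from _loss_sketch import Logloss  # hypothetical compiled module

y_true = np.array([0.0, 1.0, 1.0])
raw = np.array([-1.2, 0.3, 2.0])
print(Logloss().loss(y_true, raw))  # per-sample log-loss values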

Alternatives:

Just for fun and, unfortunately unresolved: have a look at this old mail thread [Cython] Wacky idea: proper macros.

@rth
Member Author

rth commented Nov 26, 2020

Thanks for doing this analysis @lorentzenchr !

def loss and gradient functions that operate on numpy arrays (or memoryviews). Ideally, they call the cdef versions in a loop.

If we don't have this constraint, and just copy the pointwise implementation as needed with a range or a prange, would that simplify things? Or is the virtual inheritance still an issue?

With the above assumptions and the benchmarks, I propose the following sketched design:

The proposed design sounds reasonable to me but I'm still a bit surprised by the slowdown we see in the benchmarks above with respect to the plain numpy implementation. Do you have an idea what this could be due to?

Time ratio of log-loss (compared to numpy np_logloss, smaller is better)

In that example, if np_logloss is faster couldn't we just use it for the vectorized version? Or is the issue that we need to release GIL?

@lorentzenchr
Member

def loss and gradient functions that operate on numpy arrays (or memoryviews). Ideally, they call the cdef versions in a loop.

If we don't have this constraint, and just copy the pointwise implementation as needed with a range or a prange, would that simplify things? Or is the virtual inheritence still an issue?

Maybe I misunderstand you. My goal is to avoid implementing the range/prange loop every single time, as it would be repeated #losses * #functions > 10 times at least (much depending on the number of losses; #functions ~ 3).

The proposed design sounds reasonable to me but I'm still a bit surprised by the slowdown we see in the benchmarks above with respect to the plain numpy implementation. Do you have an idea what this could be due to?

The plain numpy version gets slower with a numerically more stable solution. In the benchmark, it is np.log1p(np.exp(raw)) - y_true * raw. I guess it is so fast because it can first vectorize np.exp(raw) using SIMD instructions (ufunc magic in numpy) and then again vectorize log1p(temporary). It uses, however, more memory, which is not much of a concern here.

In that example, if np_logloss is faster couldn't we just use it for the vectorized version? Or is the issue that we need to release GIL?

For thread parallelism via prange, we need to release the GIL and do the loop over single points ourselves.

@rth
Member Author

rth commented Nov 26, 2020

Maybe, I misunderstand you. My goal is to avoid implementing the range/prange every single time as it would be repeated #losses * #functions > 10 times at least (much depending on number of losses, #functions ~ 3).

As a default approach that sounds reasonable. I'm just saying that in some cases writing that code twice without factorizing it into a function might also help some optimizations in the for loops: either because of auto-vectorization (with the use of SIMD intrinsics) by the compiler when pre-computing some quantities, or because it's more cache friendly (e.g. due to use of BLAS).

For instance with the hinge loss implemented for arrays, it's probably enough to allocate an array of zeros, and then set all the elements above the threshold. There is no need to set to 0 the elements that are already zero, which would be done in the factorized function. There OpenMP would also likely not help, as it's a very simple and memory-bound operation. In general I'm a bit wary of prange used everywhere, unless one can demonstrate that it's indeed a performance bottleneck (I'm not too familiar with the HGBDT code, maybe it is there), as on multicore systems (16+ CPU cores) it leads to a lot of wasted resources and sometimes slower performance (CPU oversubscription etc).

For thread parallelism via prange, we need to release the GIL and do the loop over single points ourselfves.

In the loop indeed, but the overall method/function doesn't need to, as far as I can tell (e.g. p_logloss in the notebook example), since we need to allocate the output array. So if the performance is similar or better, and there isn't too much repeated code, we could potentially have used plain numpy there as well.

Overall I would be +1 to start with the approach you proposed with factorized function, but without OpenMP. And then tune performance if needed on a case by case basis.

@lorentzenchr
Member

I updated my notebook with a numerically more stable implementation of the log loss in Chapter 2. In this case the hand-written implementation wins over pure numpy.
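The notebook itself is not reproduced here; for reference, one common way to write a numerically stable per-sample binary log-loss in plain numpy (a sketch; it avoids overflowing exp for large positive raw) is:

import numpy as np

def stable_logloss(y_true, raw):
    # log(1 + exp(raw)) - y_true * raw, computed without overflow:
    #   raw <= 0: log1p(exp(raw))
    #   raw >  0: raw + log1p(exp(-raw))
    return np.where(
        raw <= 0,
        np.log1p(np.exp(np.minimum(raw, 0.0))),
        raw + np.log1p(np.exp(-np.maximum(raw, 0.0))),
    ) - y_true * raw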

@lorentzenchr
Member

How to best get such a deep code change in? I propose one PR for the new _loss module alone, then one follow-up PR per module. Difficulties with this approach:

The big advantage is that this way is more digestible than one huge PR.

@lorentzenchr
Member

In a prototype with a generic function wrapper to carry out the loop, I got consistently worse performance than with a manual loop. This is consistent with the findings of the above-linked notebook for larger sample sizes.

Example: BinaryCrossEntropy update_gradients_and_hessians with 8 cores and 1e5 samples (mean ± std. dev. of 10 runs, 500 loops each):
HGBT: 278 µs ± 56.5 µs per loop
Function Wrapper: 317 µs ± 49.4 µs per loop
Manual Loop (Like HGBT): 278 µs ± 36.2 µs per loop

This brings me back to manual loops resulting in a lot of code duplication. But a PR with worse performance won't be accepted anyway, I think.

@rth About prange: I intend to pass n_threads everywhere so that the caller has full control.

@rth
Member Author

rth commented Dec 15, 2020

But a PR with worse performance won't be accepted anyway, I think.

Noticeably worse performance on estimator training probably not, but the question is how much of the estimator fit time is spent computing the loss and gradients. The small slowdown on the loss calculation might not matter much, for instance.

I intend to pass n_threads everywhere so that the caller has full control.

Are you sure n_threads=1 wouldn't create a thread pool in Cython and would be equivalent in performance to a serial function?

How to best get such a deep code change in? [..] I propose one PR for the new _loss module alone, then one follow-up PR per module.

Sounds good. The first PR doesn't have to be exhaustive either; if it takes into account two types of estimators, that would already be great. All of this is (or should be) private, so several rounds of smallish iterations might indeed be faster to merge.

@lorentzenchr
Member

We now have #20567 merged. Let's close this issue when some modules make use of it, e.g. (histogram) gradient boosting and linear models.

@lorentzenchr
Member

lorentzenchr commented Feb 25, 2022

Meanwhile, the new loss module is used for HGBT #20811 and in LogisticRegression #21808.

Open are:
