
[MRG+1] Patch liblinear for sample_weights in LogisticRegression(and CV) #5274


Merged

merged 3 commits into scikit-learn:master from liblinear_samples_weights on Oct 23, 2015

Conversation

MechCoder
Member

This adds sample_weight support to the liblinear solver for LogisticRegression and LogisticRegressionCV. It was already added to the other solvers in another PR.
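For context, a minimal usage sketch of what this enables (synthetic data, illustrative only):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Synthetic two-class data, illustrative only.
rng = np.random.RandomState(0)
X = rng.randn(100, 4)
y = (X[:, 0] + 0.5 * rng.randn(100) > 0).astype(int)
sample_weight = rng.uniform(0.5, 2.0, size=100)

# With this patch, the liblinear solver accepts per-sample weights,
# matching the other solvers.
clf = LogisticRegression(solver="liblinear", fit_intercept=False)
clf.fit(X, y, sample_weight=sample_weight)
print(clf.coef_)
```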

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch 2 times, most recently from 0aff304 to a700d3c Compare September 15, 2015 21:59
@MechCoder MechCoder changed the title [WIP] Patch liblinear for sample_weights in LogisticRegression(and CV) [MRG] Patch liblinear for sample_weights in LogisticRegression(and CV) Sep 15, 2015
@MechCoder MechCoder force-pushed the liblinear_samples_weights branch from a700d3c to 2ef06ac Compare September 15, 2015 22:02
@MechCoder
Member Author

ping @vstolbunov . Also @fabianp could you have a look, since you know this part of the code very well.

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch from 2ef06ac to 268b1bc Compare September 15, 2015 22:06
@MechCoder
Member Author

tests pass now.

@vstolbunov
Contributor

I took a look last night and they had passed, so I wasn't sure what the problem was.

@MechCoder
Member Author

I think I forgot to build it locally, and hence got some segmentation faults. But now it's all right.

@ogrisel
Member

ogrisel commented Sep 23, 2015

@TomDLT you might want to review this.

```python
# Test the above for l1 penalty and l2 penalty with dual=True,
# since the patched liblinear code is different.
clf_cw = LogisticRegression(
    solver="liblinear", fit_intercept=False, class_weight={0:1, 1:2},
```

`{0: 1, 1: 2}`

@TomDLT
Member

TomDLT commented Sep 24, 2015

You should also compare the results with liblinear and other solvers, like in this test.

And testing it reveals that liblinear does not handle a sample_weight given as integers.
The problem is that this check is not done, since liblinear skips everything here.
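A minimal sketch of the kind of validation that would address this, assuming the weights are coerced to float64 before reaching the liblinear binding (the helper name is hypothetical):

```python
import numpy as np

def _as_float64_weights(sample_weight, n_samples):
    # Hypothetical helper: coerce integer (or list) weights to
    # float64 so the C code always receives doubles, and default
    # to uniform weights when none are given.
    if sample_weight is None:
        return np.ones(n_samples, dtype=np.float64)
    sample_weight = np.asarray(sample_weight, dtype=np.float64)
    if sample_weight.shape[0] != n_samples:
        raise ValueError("sample_weight and X have inconsistent lengths")
    return sample_weight
```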

```python
clf_sw_none.fit(X, y)
clf_sw_ones = LR(solver=solver, fit_intercept=False)
clf_sw_ones.fit(X, y, sample_weight=np.ones(y.shape[0]))
assert_array_almost_equal(clf_sw_none.coef_, clf_sw_ones.coef_, decimal=4)
```

line too long
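One way to wrap it within 79 characters:

```python
assert_array_almost_equal(clf_sw_none.coef_, clf_sw_ones.coef_,
                          decimal=4)
```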

@TomDLT
Member

TomDLT commented Sep 24, 2015

It looks pretty good to me. (yet I am not a C++ master)

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch from 6e6ec68 to 867596b Compare September 24, 2015 17:56
@MechCoder
Member Author

@TomDLT thanks for your reviews. I've fixed it up.

Any second reviewers?
cc @jnothman @ogrisel

@MechCoder
Member Author

@TomDLT Can you update the PR to MRG+1 if you are happy?

(Btw, I don't know what your definition of a C++ master is, but whatever it is I'm not one either :P)

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch from 867596b to 5134286 Compare September 24, 2015 18:02
@TomDLT TomDLT changed the title [MRG] Patch liblinear for sample_weights in LogisticRegression(and CV) [MRG+1] Patch liblinear for sample_weights in LogisticRegression(and CV) Sep 25, 2015
@ogrisel
Member

ogrisel commented Oct 19, 2015

If you are looking for a C++ master, @larsmans is a good candidate :)

```diff
@@ -10,7 +10,7 @@
 from ..base import BaseEstimator, ClassifierMixin, ChangedBehaviorWarning
 from ..preprocessing import LabelEncoder
 from ..multiclass import _ovr_decision_function
-from ..utils import check_array, check_random_state, column_or_1d
+from ..utils import check_array, check_consistent_length, check_random_state, column_or_1d
```

PEP8, put the check_consistent_length import on its own line.
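That is, something like:

```python
from ..utils import check_array, check_random_state, column_or_1d
from ..utils import check_consistent_length
```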

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch 2 times, most recently from 75758f6 to 84668a7 Compare October 19, 2015 19:05
@jnothman
Member

You should also support this in LinearSVC.

@MechCoder
Member Author

I thought that was for LinearSVC

@jnothman
Member

> I thought that was for LinearSVC

I'd thought it was a general patch to liblinear, but maybe...

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch from b11922c to c01faae Compare October 20, 2015 14:05
@MechCoder
Member Author

@jnothman So we'll merge this and add support for other solvers later?

@TomDLT
Member

TomDLT commented Oct 21, 2015

@ogrisel @larsmans @jnothman what are your views on this PR?

@MechCoder MechCoder force-pushed the liblinear_samples_weights branch from c01faae to 9d8d7e4 Compare October 23, 2015 04:07
@MechCoder
Member Author

Rebased. It would be great if someone could give a final +1.

@agramfort
Member

Played with this a bit and it worked great. Merging.

agramfort added a commit that referenced this pull request Oct 23, 2015
[MRG+1] Patch liblinear for sample_weights in LogisticRegression(and CV)
@agramfort agramfort merged commit 1c5d6d7 into scikit-learn:master Oct 23, 2015
@amueller
Member

The pyx file doesn't compile with Cython 0.21, 0.22, or 0.23. You used 0.20; I'll try that next. I'm pretty scared of the casting that is going on there. This was found by @arthurmensch in #5557.

@amueller
Member

Installed 0.20, still doesn't compile: `liblinear.pyx:39:35: Cannot assign type 'char *' to 'double *'`. Did you forget to call cython?
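For reference, Cython's numpy declarations type `ndarray.data` as `char *`, so handing it to C code expecting `double *` needs an explicit cast. A minimal illustrative sketch, not the actual patch:

```cython
# Illustrative sketch only: ndarray.data is declared char * in
# Cython's numpy pxd, so an explicit cast is required before
# passing it to C code that expects double *.
cimport numpy as np

def set_weights(np.ndarray[np.float64_t, ndim=1, mode='c'] sample_weight):
    cdef double *sw = <double *> sample_weight.data
    # ... sw can now be handed to the liblinear C API ...
```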

@MechCoder
Member Author

I think I just forgot to push the generated C files. Just a second.

@niteshroyal

niteshroyal commented Jun 9, 2018

How does the objective function change in the case of sample_weight for logistic regression? Can you please provide the mathematical expression?

I assume the objective function changes like this:

$$E(\mathbf{w}) = - \sum_{n=1}^{N} \left\{ s_n t_n \ln y_n + (1 - s_n t_n) \ln(1 - y_n) \right\}$$

where $s_n$ is the sample_weight of the nth sample.

The above equation is modified from equation 4.90 of Christopher Bishop's PRML book.

Clarification: the equation is written in LaTeX; I could not post an image.
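For comparison, the usual convention (and the one matching "samples repeated in proportion to their weight", quoted later in this thread) scales the entire per-sample loss by $s_n$, rather than weighting only part of it:

```latex
% Weighted cross-entropy as usually implemented: the weight s_n
% scales the entire per-sample loss, so an integer weight of k is
% equivalent to repeating the sample k times.
E(\mathbf{w}) = - \sum_{n=1}^{N} s_n \left\{ t_n \ln y_n
    + (1 - t_n) \ln (1 - y_n) \right\}
```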

@amueller
Member

amueller commented Jun 9, 2018

@niteshroyal this is not the right place to ask usage questions, see http://scikit-learn.org/dev/faq.html#what-s-the-best-way-to-get-help-on-scikit-learn-usage

@jnothman
Member

jnothman commented Jun 9, 2018 via email

Ordinarily, weighting means solving an objective that is equivalent to having the samples repeated in proportion to their weight.

@memeplex

memeplex commented Aug 16, 2018

> ordinarily, weighting means solving an objective that is equivalent to having the samples repeated in proportion to their weight

@jnothman I have seen that when the class weights are too imbalanced, adding more degrees of freedom to the model won't always result in a lower (accordingly weighted) log_loss. So I suspect that, barring a numerical issue I'm unaware of, the objective function is not exactly the one log_loss is evaluating (equivalent to "samples repeated in proportion to their weight"). Of course, this isn't a generalization problem: the loss was computed over the training set, and the parameter C is set to 1e30 so that regularization is virtually disabled. If this is unexpected behavior (it is for me!), I could provide the data and model to reproduce it.

@memeplex

Ok, I think it was a numerical issue indeed: playing with the tol parameter, I managed to get a loss that decreases as the degrees of freedom increase. This is for a dataset with a very low conversion rate (p < 1e-4), so the convergence criterion is critical.
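A minimal sketch of that check, with illustrative data: tighten `tol` (and raise `max_iter`) so the optimizer converges far enough for the weighted training log-loss to behave as expected:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss

# Illustrative data: a rare positive class with imbalanced weights.
rng = np.random.RandomState(0)
X = rng.randn(200, 5)
y = (rng.rand(200) < 0.1).astype(int)
w = np.where(y == 1, 10.0, 1.0)

# C=1e30 effectively disables regularization; a tight tol lets the
# optimizer converge far enough that the weighted training log-loss
# actually decreases as model capacity is added.
clf = LogisticRegression(solver="liblinear", C=1e30, tol=1e-10,
                         max_iter=10000)
clf.fit(X, y, sample_weight=w)
proba = clf.predict_proba(X)[:, 1]
print(log_loss(y, proba, sample_weight=w))
```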
