
[MRG + 1] Improve tests for sample_weight in LinearRegression and Ridge #5526


Merged: 1 commit merged into scikit-learn:master on Nov 11, 2015

Conversation

@giorgiop (Contributor)

I have noticed in #5357 that a test introduced in 0.17 was not correctly checking for the effect of sample_weight in LinearRegression. To reproduce the issue on master:

import numpy as np
from sklearn.linear_model.base import LinearRegression

rng = np.random.RandomState(0)
y = rng.randn(6)
X = rng.randn(6, 5)
w = 1.0 + rng.rand(6)

clf = LinearRegression()
clf.fit(X, y, sample_weight=w)
coefs1 = clf.coef_

scaled_y = y * np.sqrt(w)
scaled_X = X * np.sqrt(w)[:, np.newaxis]
clf.fit(scaled_X, scaled_y)
coefs2 = clf.coef_

print(coefs1)
[ 4.58791686 -4.2095038   0.39031788  3.2727146  -0.17386704]
print(coefs2)
[ 3.69763237 -3.64824351  0.367363    2.97550307 -0.44881672]

@giorgiop changed the title from "BUG fix weight in LinearRegression" to "BUG fix sample_weight in LinearRegression" on Oct 22, 2015
@giorgiop force-pushed the fix-weight-linearregression branch from b58187f to 66633b8 (October 22, 2015 12:19)
@amueller changed the title from "BUG fix sample_weight in LinearRegression" to "[MRG] BUG fix sample_weight in LinearRegression" on Oct 22, 2015
@amueller added this to the 0.17 milestone on Oct 22, 2015
@amueller (Member)

ping @sonnyhu @agramfort @ogrisel ?

@amueller (Member)

or @GaelVaroquaux, who reviewed the PR together with me.

@amueller (Member)

LGTM

@amueller changed the title from "[MRG] BUG fix sample_weight in LinearRegression" to "[MRG + 1] BUG fix sample_weight in LinearRegression" on Oct 22, 2015
@eickenberg (Contributor)

hmm, shouldn't this also show up in @ainafp 's common tests?

@eickenberg (Contributor)

#5515

@sonnyhu commented Oct 22, 2015

👍

if sample_weight is not None:
    # Sample weight can be implemented via a simple rescaling.
    X, y = _rescale_data(X, y, sample_weight)

X, y, X_mean, y_mean, X_std = self._center_data(
    X, y, self.fit_intercept, self.normalize, self.copy_X,
    sample_weight=None)
@agramfort (Member)

This is not equivalent, as now you rescale the non-centered data, which means that sample_weight interacts with the intercept.

Can you clarify why it's more correct this way?

@agramfort (Member)

or I saw #5357 (comment)

@Sandy4321

How do I update scikit-learn to get this bug fixed?


@giorgiop (Contributor, Author)

@Sandy4321 we are working on solving the issue at the moment. In the meantime, if you want to use sample weights, you can instead multiply X (row by row) and y by the square root of the weights, as is done in the code above. Fitting LinearRegression without sample weights still runs without issues.
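For reference, a minimal sketch of that workaround (exact for fit_intercept=False; with an intercept, the dummy-column rescaling discussed below is needed):

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
X = rng.randn(6, 5)
y = rng.randn(6)
w = 1.0 + rng.rand(6)
sw = np.sqrt(w)

# Rescale each row of X and each entry of y by sqrt(w_i),
# then fit without passing sample_weight.
clf = LinearRegression(fit_intercept=False)
clf.fit(X * sw[:, np.newaxis], y * sw)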

@Sandy4321

I see, but the question is: how technically do I do this? Simply use pip to reinstall?


@giorgiop (Contributor, Author)

@agramfort We are reducing a weighted least squares problem to an unweighted one by first performing a change of variables. In the new variables, normalization/centering are weight-independent.

I believe intercept_ has to interact with sample_weight. If we think of the intercept as a dummy column of X, then from a linear algebra point of view the diagonal matrix of sample weights multiplies that column too.
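In formulas, with W = diag(w_1, ..., w_n) and the intercept written as a dummy all-ones column of X:

\min_{\beta, c} \sum_{i=1}^{n} w_i \,(y_i - x_i^\top \beta - c)^2
  = \min_{\beta, c} \bigl\| W^{1/2} y - W^{1/2} X \beta - c\, W^{1/2} \mathbf{1} \bigr\|_2^2 ,

so after the change of variables \tilde{X} = W^{1/2} X, \tilde{y} = W^{1/2} y the problem is ordinary least squares, and the dummy intercept column \mathbf{1} is rescaled to (\sqrt{w_1}, \dots, \sqrt{w_n})^\top like any other column.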

@GaelVaroquaux (Member)

GaelVaroquaux commented Oct 22, 2015 via email

@giorgiop (Contributor, Author)

It is. We are minimizing a (weighted) L2 norm here. Indeed, the current implementation does the trick a priori, before the change of variables. See here.

@giorgiop force-pushed the fix-weight-linearregression branch 2 times, most recently from 70ea898 to c1c35d0 (October 22, 2015 14:42)
@MechCoder (Member)

I think I agree with @giorgiop.

More simply, if we just take the loss function as \sum_{i=1}^{n} w_i (y_i - x_i^\top \beta - c)^2, then even the intercept c gets weighted by a \sqrt{w_i} term, no?


for n_samples, n_features in ((6, 5), (5, 10)):
    for fit_intercept in [True, False]:
Review comment from a Member:

This should just be a bug with fit_intercept=True, right? But I guess it is good to test.

@eickenberg (Contributor)

IIRC, when you orthogonalize against a constant (i.e. center) in order to remove the intercept from the optimization, the later recalculation incurs the multiplication of that sqrt(w) factor twice, so it just ends up being a weighted average with weights w. The second factor comes from the fact that we orthogonalize with respect to a weighted scalar product.
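Concretely, the two \sqrt{w} factors combine as follows: projecting W^{1/2} y onto the constant direction W^{1/2} \mathbf{1} gives

\frac{(W^{1/2} \mathbf{1})^\top (W^{1/2} y)}{\| W^{1/2} \mathbf{1} \|_2^2}
  = \frac{\sum_i w_i y_i}{\sum_i w_i},

i.e. the weighted mean with weights w.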


@giorgiop (Contributor, Author)

hmm, shouldn't this also show up in @ainafp 's common tests?

I don't know. I had a look at the test but I don't have an answer just yet.

@giorgiop force-pushed the fix-weight-linearregression branch 2 times, most recently from 2c2897a to c7e276a (October 23, 2015 10:07)
@giorgiop (Contributor, Author)

So, LinearRegression does not break with the kind of tests of #5515. @eickenberg @ainafp

import numpy as np
from sklearn.linear_model.base import LinearRegression

rng = np.random.RandomState(0)
y = rng.randn(6)
X = rng.randn(6, 5)
w = 2 * np.ones(6)

clf = LinearRegression()
clf.fit(X, y, sample_weight=w)
coefs1 = clf.coef_

# Duplicating every sample is equivalent to a constant sample weight of 2.
scaled_y = np.hstack((y, y))
scaled_X = np.vstack((X, X))
clf.fit(scaled_X, scaled_y)
coefs2 = clf.coef_

coefs1
array([ 4.58791686, -4.2095038 ,  0.39031788,  3.2727146 , -0.17386704])
coefs2
array([ 4.58791686, -4.2095038 ,  0.39031788,  3.2727146 , -0.17386704])

@eickenberg (Contributor)

#5515 shouldn't have a constant sample weight vector; it is drawn uniformly from a range of integers. That should be different.


@giorgiop (Contributor, Author)

Using w = rng.randint(1,5,6) still makes the coefficients equal.
My point is that #5515 cannot cover the issue here and I still have to understand why.

@eickenberg (Contributor)

maybe linear regression is not among the checked estimators for some reason?


@giorgiop force-pushed the fix-weight-linearregression branch 2 times, most recently from b2e6dcb to 2dcad5e (November 10, 2015 15:59)
@giorgiop (Contributor, Author)

@mblondel I think that kind of test is not supposed to pass. I have removed it and used the idea based on weighted least squares that we discussed, with data augmentation by a dummy variable. It should be all right. I am amending with some cosmetics, as suggested by @agramfort in person, and squashing.
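As a rough sketch of that idea (illustrative only; the actual test in the PR may differ, and rcond=None assumes a recent NumPy), the weighted fit can be checked against a closed-form weighted least-squares solution with the intercept appended as a dummy column:

import numpy as np
from sklearn.linear_model import LinearRegression

rng = np.random.RandomState(0)
X = rng.randn(10, 5)
y = rng.randn(10)
w = 1.0 + rng.rand(10)
sw = np.sqrt(w)

# Closed-form weighted least squares: append a dummy all-ones column for
# the intercept, rescale rows by sqrt(w), then solve ordinary least squares.
X_aug = np.hstack([X, np.ones((10, 1))])
coefs, _, _, _ = np.linalg.lstsq(X_aug * sw[:, np.newaxis], y * sw, rcond=None)

clf = LinearRegression().fit(X, y, sample_weight=w)
np.testing.assert_array_almost_equal(clf.coef_, coefs[:-1])
np.testing.assert_almost_equal(clf.intercept_, coefs[-1])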

@giorgiop (Contributor, Author)

@agramfort also suggested that we move the test based on weighted least squares to test_common, where we may loop over all regressors that minimize a squared loss, fitting with very small regularization coefficients.
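A rough sketch of that common-test idea (estimator list and tolerance are illustrative, not the PR's code):

import numpy as np
from sklearn.linear_model import LinearRegression, Ridge

rng = np.random.RandomState(0)
X = rng.randn(10, 5)
y = rng.randn(10)
w = 1.0 + rng.rand(10)

# With negligible regularization, any regressor minimizing a squared loss
# should agree with plain weighted least squares, i.e. LinearRegression.
reference = LinearRegression().fit(X, y, sample_weight=w)
for reg in (Ridge(alpha=1e-10),):
    reg.fit(X, y, sample_weight=w)
    np.testing.assert_array_almost_equal(reg.coef_, reference.coef_, decimal=5)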

@eickenberg (Contributor)

That looks like it could be done within or in conjunction with the other PR
on a common test for sample weight effects


@giorgiop force-pushed the fix-weight-linearregression branch from 2dcad5e to b76d5a4 (November 10, 2015 16:07)
@giorgiop (Contributor, Author)

Sure we can. It's going to follow a different logic, but has the same goal.

@agramfort (Member)

agramfort commented Nov 10, 2015 via email

@giorgiop changed the title from "[MRG + 1] BUG fix sample_weight in LinearRegression" to "[MRG + 1] Improve tests for sample_weight in LinearRegression and Ridge" on Nov 10, 2015
@giorgiop (Contributor, Author)

let me know when I shall review

Please go ahead!


assert_array_almost_equal(coefs1, coefs2)
if intercept is False:
    assert_array_almost_equal(coefs1, coefs3)
Review comment from a Member:

coefs3 -> coefs2

@agramfort (Member)

also please add note that the same tests are needed for sparse data

@agramfort (Member)

that's it for me!

@giorgiop (Contributor, Author)

also please add note that the same tests are needed for sparse data

There is a comment already in the code for this. I will move it to a place where it is more evident.

@giorgiop force-pushed the fix-weight-linearregression branch from b76d5a4 to 1d16ec4 (November 10, 2015 21:08)
@agramfort (Member)

agramfort commented Nov 10, 2015 via email

@giorgiop (Contributor, Author)

Comments should have been addressed.

@mblondel (Member)

I have removed it and used the idea based on weighted least squares that we discussed, with data augmentation by a dummy variable.

Indeed this is a much more solid test. Thanks for that!

agramfort added a commit that referenced this pull request Nov 11, 2015
[MRG + 1] Improve tests for sample_weight in LinearRegression and Ridge
@agramfort merged commit a5d6144 into scikit-learn:master on Nov 11, 2015
@agramfort (Member)

thanks @giorgiop !

@giorgiop deleted the fix-weight-linearregression branch on November 11, 2015 17:40
 rng = np.random.RandomState(0)

-for n_samples, n_features in ((6, 5), (5, 10)):
+# It would not work with under-determined systems
+for n_samples, n_features in ((6, 5), ):
Review comment from a Member:

this loop is pretty useless now, right?

@giorgiop (Contributor, Author)

Yes it is, my bad.
