Skip to content

Commit 190abde

Browse files
committed
Merge pull request #4881 from sonnyhu/weighted_least_squares
[MRG + 1] add sample_weight into LinearRegression
2 parents 55c32ef + a88f6ec commit 190abde

File tree

4 files changed

+99
-19
lines changed

4 files changed

+99
-19
lines changed

sklearn/ensemble/tests/test_weight_boosting.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,13 +286,13 @@ def test_base_estimator():
286286

287287

288288
def test_sample_weight_missing():
289-
from sklearn.linear_model import LinearRegression
289+
from sklearn.linear_model import LogisticRegression
290290
from sklearn.cluster import KMeans
291291

292-
clf = AdaBoostClassifier(LinearRegression(), algorithm="SAMME")
292+
clf = AdaBoostClassifier(LogisticRegression(), algorithm="SAMME")
293293
assert_raises(ValueError, clf.fit, X, y_regr)
294294

295-
clf = AdaBoostRegressor(LinearRegression())
295+
clf = AdaBoostRegressor(LogisticRegression())
296296
assert_raises(ValueError, clf.fit, X, y_regr)
297297

298298
clf = AdaBoostClassifier(KMeans(), algorithm="SAMME")

sklearn/linear_model/base.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from ..externals import six
2626
from ..externals.joblib import Parallel, delayed
2727
from ..base import BaseEstimator, ClassifierMixin, RegressorMixin
28-
from ..utils import as_float_array, check_array, check_X_y, deprecated
28+
from ..utils import as_float_array, check_array, check_X_y, deprecated, column_or_1d
2929
from ..utils.extmath import safe_sparse_dot
3030
from ..utils.sparsefuncs import mean_variance_axis, inplace_column_scale
3131
from ..utils.fixes import sparse_lsqr
@@ -113,6 +113,18 @@ def center_data(X, y, fit_intercept, normalize=False, copy=True,
113113
return X, y, X_mean, y_mean, X_std
114114

115115

116+
def _rescale_data(X, y, sample_weight):
117+
"""Rescale data so as to support sample_weight"""
118+
n_samples = X.shape[0]
119+
sample_weight = sample_weight * np.ones(n_samples)
120+
sample_weight = np.sqrt(sample_weight)
121+
sw_matrix = sparse.dia_matrix((sample_weight, 0),
122+
shape=(n_samples, n_samples))
123+
X = safe_sparse_dot(sw_matrix, X)
124+
y = safe_sparse_dot(sw_matrix, y)
125+
return X, y
126+
127+
116128
class LinearModel(six.with_metaclass(ABCMeta, BaseEstimator)):
117129
"""Base class for Linear Models"""
118130

@@ -352,7 +364,7 @@ def __init__(self, fit_intercept=True, normalize=False, copy_X=True,
352364
self.copy_X = copy_X
353365
self.n_jobs = n_jobs
354366

355-
def fit(self, X, y):
367+
def fit(self, X, y, sample_weight=None):
356368
"""
357369
Fit linear model.
358370
@@ -364,16 +376,28 @@ def fit(self, X, y):
364376
y : numpy array of shape [n_samples, n_targets]
365377
Target values
366378
379+
sample_weight : numpy array of shape [n_samples]
380+
Individual weights for each sample
381+
367382
Returns
368383
-------
369384
self : returns an instance of self.
370385
"""
386+
371387
n_jobs_ = self.n_jobs
372388
X, y = check_X_y(X, y, accept_sparse=['csr', 'csc', 'coo'],
373389
y_numeric=True, multi_output=True)
390+
391+
if ((sample_weight is not None) and np.atleast_1d(sample_weight).ndim > 1):
392+
sample_weight = column_or_1d(sample_weight, warn=True)
374393

375394
X, y, X_mean, y_mean, X_std = self._center_data(
376-
X, y, self.fit_intercept, self.normalize, self.copy_X)
395+
X, y, self.fit_intercept, self.normalize, self.copy_X,
396+
sample_weight=sample_weight)
397+
398+
if sample_weight is not None:
399+
# Sample weight can be implemented via a simple rescaling.
400+
X, y = _rescale_data(X, y, sample_weight)
377401

378402
if sp.issparse(X):
379403
if y.ndim < 2:

sklearn/linear_model/ridge.py

Lines changed: 1 addition & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from scipy import sparse
1818
from scipy.sparse import linalg as sp_linalg
1919

20-
from .base import LinearClassifierMixin, LinearModel
20+
from .base import LinearClassifierMixin, LinearModel, _rescale_data
2121
from ..base import RegressorMixin
2222
from ..utils.extmath import safe_sparse_dot
2323
from ..utils import check_X_y
@@ -185,17 +185,6 @@ def _solve_svd(X, y, alpha):
185185
return np.dot(Vt.T, d_UT_y).T
186186

187187

188-
def _rescale_data(X, y, sample_weight):
189-
"""Rescale data so as to support sample_weight"""
190-
n_samples = X.shape[0]
191-
sample_weight = sample_weight * np.ones(n_samples)
192-
sample_weight = np.sqrt(sample_weight)
193-
sw_matrix = sparse.dia_matrix((sample_weight, 0),
194-
shape=(n_samples, n_samples))
195-
X = safe_sparse_dot(sw_matrix, X)
196-
y = safe_sparse_dot(sw_matrix, y)
197-
return X, y
198-
199188

200189
def ridge_regression(X, y, alpha, sample_weight=None, solver='auto',
201190
max_iter=None, tol=1e-3, verbose=0):

sklearn/linear_model/tests/test_base.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
from sklearn.utils.testing import assert_equal
1111

1212
from sklearn.linear_model.base import LinearRegression
13-
from sklearn.linear_model.base import center_data, sparse_center_data
13+
from sklearn.linear_model.base import center_data, sparse_center_data, _rescale_data
1414
from sklearn.utils import check_random_state
15+
from sklearn.utils.testing import assert_raise_message
16+
from sklearn.utils.testing import assert_greater
1517
from sklearn.datasets.samples_generator import make_sparse_uncorrelated
1618
from sklearn.datasets.samples_generator import make_regression
1719

@@ -40,6 +42,55 @@ def test_linear_regression():
4042
assert_array_almost_equal(clf.predict(X), [0])
4143

4244

45+
def test_linear_regression_sample_weights():
46+
rng = np.random.RandomState(0)
47+
48+
for n_samples, n_features in ((6, 5), (5, 10)):
49+
y = rng.randn(n_samples)
50+
X = rng.randn(n_samples, n_features)
51+
sample_weight = 1.0 + rng.rand(n_samples)
52+
53+
clf = LinearRegression()
54+
clf.fit(X, y, sample_weight)
55+
coefs1 = clf.coef_
56+
57+
assert_equal(clf.coef_.shape, (X.shape[1], ))
58+
assert_greater(clf.score(X, y), 0.9)
59+
assert_array_almost_equal(clf.predict(X), y)
60+
61+
# Sample weight can be implemented via a simple rescaling
62+
# for the square loss.
63+
scaled_y = y * np.sqrt(sample_weight)
64+
scaled_X = X * np.sqrt(sample_weight)[:, np.newaxis]
65+
clf.fit(X, y)
66+
coefs2 = clf.coef_
67+
68+
assert_array_almost_equal(coefs1, coefs2)
69+
70+
71+
def test_raises_value_error_if_sample_weights_greater_than_1d():
72+
# Sample weights must be either scalar or 1D
73+
74+
n_sampless = [2, 3]
75+
n_featuress = [3, 2]
76+
77+
rng = np.random.RandomState(42)
78+
79+
for n_samples, n_features in zip(n_sampless, n_featuress):
80+
X = rng.randn(n_samples, n_features)
81+
y = rng.randn(n_samples)
82+
sample_weights_OK = rng.randn(n_samples) ** 2 + 1
83+
sample_weights_OK_1 = 1.
84+
sample_weights_OK_2 = 2.
85+
86+
clf = LinearRegression()
87+
88+
# make sure the "OK" sample weights actually work
89+
clf.fit(X, y, sample_weights_OK)
90+
clf.fit(X, y, sample_weights_OK_1)
91+
clf.fit(X, y, sample_weights_OK_2)
92+
93+
4394
def test_fit_intercept():
4495
# Test assertions on betas shape.
4596
X2 = np.array([[0.38349978, 0.61650022],
@@ -255,3 +306,19 @@ def test_csr_sparse_center_data():
255306
csr = sparse.csr_matrix(X)
256307
csr_, y, _, _, _ = sparse_center_data(csr, y, True)
257308
assert_equal(csr_.getformat(), 'csr')
309+
310+
311+
def test_rescale_data():
312+
n_samples = 200
313+
n_features = 2
314+
315+
rng = np.random.RandomState(0)
316+
sample_weight = 1.0 + rng.rand(n_samples)
317+
X = rng.rand(n_samples, n_features)
318+
y = rng.rand(n_samples)
319+
rescaled_X, rescaled_y = _rescale_data(X, y, sample_weight)
320+
rescaled_X2 = X * np.sqrt(sample_weight)[:, np.newaxis]
321+
rescaled_y2 = y * np.sqrt(sample_weight)
322+
assert_array_almost_equal(rescaled_X, rescaled_X2)
323+
assert_array_almost_equal(rescaled_y, rescaled_y2)
324+

0 commit comments

Comments
 (0)