Skip to content

[WIP] sample_weight support #1574

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 46 additions & 16 deletions sklearn/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,8 @@ def __len__(self):

##############################################################################


def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
def cross_val_score(estimator, X, y=None, sample_weight=None,
scoring=None, cv=None, n_jobs=1,
verbose=0, fit_params=None, score_func=None,
pre_dispatch='2*n_jobs'):
"""Evaluate a score by cross-validation
Expand All @@ -1086,6 +1086,9 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
The target variable to try to predict in the case of
supervised learning.

sample_weight : array-like, optional, default: None
Sample weights.

scoring : string, callable or None, optional, default: None
A string (see model evaluation documentation) or
a scorer callable object / function with signature
Expand Down Expand Up @@ -1128,25 +1131,31 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
scores : array of float, shape=(len(list(cv)),)
Array of scores of the estimator for each run of the cross validation.
"""
X, y = check_arrays(X, y, sparse_format='csr', allow_lists=True,
allow_nans=True)
X, y, sample_weight = check_arrays(X, y, sample_weight,
sparse_format='csr',
allow_lists=True,
allow_nans=True)
if y is not None:
y = np.asarray(y)
if sample_weight is not None:
sample_weight = np.asarray(sample_weight)

cv = _check_cv(cv, X, y, classifier=is_classifier(estimator))
scorer = check_scoring(estimator, score_func=score_func, scoring=scoring)
# We clone the estimator to make sure that all the folds are
# independent, and that it is pickle-able.
parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
pre_dispatch=pre_dispatch)
scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y,
sample_weight, scorer,
train, test, verbose, None,
fit_params)
for train, test in cv)
return np.array(scores)[:, 0]


def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
def _fit_and_score(estimator, X, y, sample_weight,
scorer, train, test, verbose, parameters,
fit_params, return_train_score=False,
return_parameters=False):
"""Fit estimator and compute scores for a given dataset split.
Expand All @@ -1159,10 +1168,13 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
X : array-like of shape at least 2D
The data to fit.

y : array-like, optional, default: None
y : array-like or None
The target variable to try to predict in the case of
supervised learning.

sample_weight : array-like or None
Sample weights.

scoring : callable
A scorer callable object / function with signature
``scorer(estimator, X, y)``.
Expand Down Expand Up @@ -1225,15 +1237,28 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,

start_time = time.time()

X_train, y_train = _safe_split(estimator, X, y, train)
X_test, y_test = _safe_split(estimator, X, y, test, train)
X_train, y_train, sample_weight_train = _safe_split(
estimator, X, y, sample_weight, train)
X_test, y_test, sample_weight_test = _safe_split(
estimator, X, y, sample_weight, test, train)

test_score_params = {}
train_score_params = {}
if sample_weight is not None:
fit_params = fit_params.copy()
fit_params['sample_weight'] = sample_weight_train
test_score_params['sample_weight'] = sample_weight_test
train_score_params['sample_weight'] = sample_weight_train

if y_train is None:
estimator.fit(X_train, **fit_params)
else:
estimator.fit(X_train, y_train, **fit_params)
test_score = _score(estimator, X_test, y_test, scorer)
test_score = _score(estimator, X_test, y_test, scorer,
**test_score_params)
if return_train_score:
train_score = _score(estimator, X_train, y_train, scorer)
train_score = _score(estimator, X_train, y_train, scorer,
**train_score_params)

scoring_time = time.time() - start_time

Expand All @@ -1250,7 +1275,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
return ret


def _safe_split(estimator, X, y, indices, train_indices=None):
def _safe_split(estimator, X, y, sample_weight, indices, train_indices=None):
"""Create subset of dataset and properly handle kernels."""
if hasattr(estimator, 'kernel') and callable(estimator.kernel):
# cannot compute the kernel values with custom function
Expand Down Expand Up @@ -1279,15 +1304,20 @@ def _safe_split(estimator, X, y, indices, train_indices=None):
else:
y_subset = None

return X_subset, y_subset
if sample_weight is not None:
sample_weight_subset = np.asarray(sample_weight)[indices]
else:
sample_weight_subset = None

return X_subset, y_subset, sample_weight_subset


def _score(estimator, X_test, y_test, scorer):
def _score(estimator, X_test, y_test, scorer, **params):
"""Compute the score of an estimator on a given test set."""
if y_test is None:
score = scorer(estimator, X_test)
score = scorer(estimator, X_test, **params)
else:
score = scorer(estimator, X_test, y_test)
score = scorer(estimator, X_test, y_test, **params)
if not isinstance(score, numbers.Number):
raise ValueError("scoring must return a number, got %s (%s) instead."
% (str(score), type(score)))
Expand Down
29 changes: 22 additions & 7 deletions sklearn/feature_selection/rfe.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def __init__(self, estimator, step=1, cv=None, scoring=None,
self.estimator_params = estimator_params
self.verbose = verbose

def fit(self, X, y):
def fit(self, X, y, sample_weight=None):
"""Fit the RFE model and automatically tune the number of selected
features.

Expand All @@ -318,6 +318,9 @@ def fit(self, X, y):
y : array-like, shape = [n_samples]
Target values (integers for classification, real numbers for
regression).

sample_weight : array-like, shape = [n_samples], optional (default=None)
Sample weights.
"""
X, y = check_arrays(X, y, sparse_format="csr")
# Initialization
Expand All @@ -332,17 +335,26 @@ def fit(self, X, y):

# Cross-validation
for n, (train, test) in enumerate(cv):
X_train, y_train = _safe_split(self.estimator, X, y, train)
X_test, y_test = _safe_split(self.estimator, X, y, test, train)
X_train, y_train, sample_weight_train = _safe_split(
self.estimator, X, y, sample_weight, train)
X_test, y_test, sample_weight_test = _safe_split(
self.estimator, X, y, sample_weight, test, train)

fit_params = dict()
score_params = dict()
if sample_weight is not None:
fit_params['sample_weight'] = sample_weight_train
score_params['sample_weight'] = sample_weight_test

# Compute a full ranking of the features
ranking_ = rfe.fit(X_train, y_train).ranking_
ranking_ = rfe.fit(X_train, y_train, **fit_params).ranking_
# Score each subset of features
for k in range(0, max(ranking_)):
mask = np.where(ranking_ <= k + 1)[0]
estimator = clone(self.estimator)
estimator.fit(X_train[:, mask], y_train)
score = _score(estimator, X_test[:, mask], y_test, scorer)
estimator.fit(X_train[:, mask], y_train, **fit_params)
score = _score(
estimator, X_test[:, mask], y_test, scorer, **score_params)

if self.verbose > 0:
print("Finished fold with %d / %d feature ranks, score=%f"
Expand All @@ -358,7 +370,10 @@ def fit(self, X, y):
n_features_to_select=k+1,
step=self.step, estimator_params=self.estimator_params)

rfe.fit(X, y)
if sample_weight is not None:
rfe.fit(X, y, sample_weight=sample_weight)
else:
rfe.fit(X, y)

# Set final attributes
self.support_ = rfe.support_
Expand Down
56 changes: 40 additions & 16 deletions sklearn/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# Gael Varoquaux <gael.varoquaux@normalesup.org>
# Andreas Mueller <amueller@ais.uni-bonn.de>
# Olivier Grisel <olivier.grisel@ensta.org>
# Noel Dawe <noel@dawe.me>
# License: BSD 3 clause

from abc import ABCMeta, abstractmethod
Expand Down Expand Up @@ -228,7 +229,8 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer,
n_samples_test : int
Number of test samples in this split.
"""
score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train,
score, n_samples_test, _ = _fit_and_score(estimator, X, y, None,
scorer, train,
test, verbose, parameters,
fit_params)
return score, parameters, n_samples_test
Expand Down Expand Up @@ -295,7 +297,7 @@ def __init__(self, estimator, scoring=None, loss_func=None,
self.verbose = verbose
self.pre_dispatch = pre_dispatch

def score(self, X, y=None):
def score(self, X, y=None, sample_weight=None):
"""Returns the score on the given test data and labels, if the search
estimator has been refit. The ``score`` function of the best estimator
is used, or the ``scoring`` parameter where unavailable.
Expand All @@ -310,18 +312,24 @@ def score(self, X, y=None):
Target relative to X for classification or regression;
None for unsupervised learning.

sample_weight : array-like, shape = [n_samples], optional
Sample weights.

Returns
-------
score : float

"""
kwargs = {}
if sample_weight is not None:
kwargs['sample_weight'] = sample_weight
if hasattr(self.best_estimator_, 'score'):
return self.best_estimator_.score(X, y)
return self.best_estimator_.score(X, y, **kwargs)
if self.scorer_ is None:
raise ValueError("No score function explicitly defined, "
"and the estimator doesn't provide one %s"
% self.best_estimator_)
return self.scorer_(self.best_estimator_, X, y)
return self.scorer_(self.best_estimator_, X, y, **kwargs)

@property
def predict(self):
Expand All @@ -339,7 +347,7 @@ def decision_function(self):
def transform(self):
return self.best_estimator_.transform

def _fit(self, X, y, parameter_iterable):
def _fit(self, X, y, sample_weight, parameter_iterable):
"""Actual fitting, performing the search over parameters."""

estimator = self.estimator
Expand All @@ -349,15 +357,21 @@ def _fit(self, X, y, parameter_iterable):
score_func=self.score_func)

n_samples = _num_samples(X)
X, y = check_arrays(X, y, allow_lists=True, sparse_format='csr',
allow_nans=True)
X, y, sample_weight = check_arrays(X, y, sample_weight,
allow_lists=True,
sparse_format='csr',
allow_nans=True)

if y is not None:
if len(y) != n_samples:
raise ValueError('Target variable (y) has a different number '
'of samples (%i) than data (X: %i samples)'
% (len(y), n_samples))
y = np.asarray(y)

if sample_weight is not None:
sample_weight = np.asarray(sample_weight)

cv = check_cv(cv, X, y, classifier=is_classifier(estimator))

if self.verbose > 0:
Expand All @@ -375,9 +389,10 @@ def _fit(self, X, y, parameter_iterable):
n_jobs=self.n_jobs, verbose=self.verbose,
pre_dispatch=pre_dispatch
)(
delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
train, test, self.verbose, parameters,
self.fit_params, return_parameters=True)
delayed(_fit_and_score)(clone(base_estimator), X, y, sample_weight,
self.scorer_, train, test,
self.verbose, parameters, self.fit_params,
return_parameters=True)
for parameters in parameter_iterable
for train, test in cv)

Expand Down Expand Up @@ -419,14 +434,18 @@ def _fit(self, X, y, parameter_iterable):
self.best_score_ = best.mean_validation_score

if self.refit:
fit_params = self.fit_params
if sample_weight is not None:
fit_params = fit_params.copy()
fit_params['sample_weight'] = sample_weight
# fit the best estimator using the entire dataset
# clone first to work around broken estimators
best_estimator = clone(base_estimator).set_params(
**best.parameters)
if y is not None:
best_estimator.fit(X, y, **self.fit_params)
best_estimator.fit(X, y, **fit_params)
else:
best_estimator.fit(X, **self.fit_params)
best_estimator.fit(X, **fit_params)
self.best_estimator_ = best_estimator
return self

Expand Down Expand Up @@ -581,7 +600,7 @@ def __init__(self, estimator, param_grid, scoring=None, loss_func=None,
self.param_grid = param_grid
_check_param_grid(param_grid)

def fit(self, X, y=None):
def fit(self, X, y=None, sample_weight=None):
"""Run fit with all sets of parameters.

Parameters
Expand All @@ -595,8 +614,10 @@ def fit(self, X, y=None):
Target relative to X for classification or regression;
None for unsupervised learning.

sample_weight : array-like, shape = [n_samples], optional
Sample weights.
"""
return self._fit(X, y, ParameterGrid(self.param_grid))
return self._fit(X, y, sample_weight, ParameterGrid(self.param_grid))


class RandomizedSearchCV(BaseSearchCV):
Expand Down Expand Up @@ -732,7 +753,7 @@ def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
pre_dispatch=pre_dispatch)

def fit(self, X, y=None):
def fit(self, X, y=None, sample_weight=None):
"""Run fit on the estimator with randomly drawn parameters.

Parameters
Expand All @@ -745,8 +766,11 @@ def fit(self, X, y=None):
Target relative to X for classification or regression;
None for unsupervised learning.

sample_weight : array-like, shape = [n_samples], optional
Sample weights.

"""
sampled_params = ParameterSampler(self.param_distributions,
self.n_iter,
random_state=self.random_state)
return self._fit(X, y, sampled_params)
return self._fit(X, y, sample_weight, sampled_params)
Loading