diff --git a/sklearn/cross_validation.py b/sklearn/cross_validation.py
index ebcf4f934f043..22e370de35577 100644
--- a/sklearn/cross_validation.py
+++ b/sklearn/cross_validation.py
@@ -1077,7 +1077,8 @@ def __len__(self):
 
 
 def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
-                    verbose=0, fit_params=None, pre_dispatch='2*n_jobs'):
+                    verbose=0, fit_params=None, pre_dispatch='2*n_jobs',
+                    scorer_params=None):
     """Evaluate a score by cross-validation
 
     Parameters
@@ -1130,6 +1131,10 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
             - A string, giving an expression as a function of n_jobs,
               as in '2*n_jobs'
 
+    scorer_params : dict, optional
+        Parameters to pass to the scorer. Can be used for sample weights
+        and sample groups.
+
     Returns
     -------
     scores : array of float, shape=(len(list(cv)),)
@@ -1143,15 +1148,15 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
     # independent, and that it is pickle-able.
     parallel = Parallel(n_jobs=n_jobs, verbose=verbose,
                         pre_dispatch=pre_dispatch)
-    scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y, scorer,
-                                              train, test, verbose, None,
-                                              fit_params)
+    scores = parallel(delayed(_fit_and_score)(clone(estimator), X, y,
+                                              scorer, train, test, verbose,
+                                              None, fit_params, scorer_params)
                       for train, test in cv)
     return np.array(scores)[:, 0]
 
 
 def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
-                   fit_params, return_train_score=False,
+                   fit_params, scorer_params, return_train_score=False,
                    return_parameters=False):
     """Fit estimator and compute scores for a given dataset split.
 
@@ -1163,7 +1168,7 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
     X : array-like of shape at least 2D
         The data to fit.
 
-    y : array-like, optional, default: None
+    y : array-like or None
         The target variable to try to predict in the case of
         supervised learning.
 
@@ -1186,6 +1191,9 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
     fit_params : dict or None
         Parameters that will be passed to ``estimator.fit``.
 
+    scorer_params : dict or None
+        Parameters that will be passed to the scorer.
+
     return_train_score : boolean, optional, default: False
         Compute and return score on training set.
 
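
The hunks above add a ``scorer_params`` argument to ``cross_val_score`` and
thread it through ``_fit_and_score``; the hunks that follow slice any
array-like entry of length ``n_samples`` per fold. A minimal usage sketch,
assuming this branch is installed; the ``weighted_accuracy`` scorer and the
synthetic data are illustrative, not part of the patch:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC
    from sklearn.cross_validation import cross_val_score

    def weighted_accuracy(estimator, X, y, sample_weight=None):
        # Scorers are called as scorer(estimator, X_test, y_test,
        # **scorer_params), so sample_weight arrives already sliced
        # to the test fold.
        pred = estimator.predict(X)
        return np.average(pred == y, weights=sample_weight)

    X, y = make_classification(n_samples=100, random_state=0)
    weights = np.random.RandomState(0).uniform(0.5, 2.0, size=len(y))

    # The full-length weight vector is passed once; _fit_and_score slices it
    # to the training fold for fit_params and to the test fold for
    # scorer_params.
    scores = cross_val_score(SVC(kernel="linear"), X, y, cv=5,
                             scoring=weighted_accuracy,
                             fit_params={'sample_weight': weights},
                             scorer_params={'sample_weight': weights})
    print(scores.mean())
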
@@ -1224,6 +1232,19 @@ def _fit_and_score(estimator, X, y, scorer, train, test, verbose, parameters,
                        if hasattr(v, '__len__') and len(v) == n_samples else v)
                        for k, v in fit_params.items()])
 
+    # Same, but take both slices
+    scorer_params = scorer_params if scorer_params is not None else {}
+    train_scorer_params = dict([(k, np.asarray(v)[train]
+                                 if hasattr(v, '__len__')
+                                 and len(v) == n_samples
+                                 else v)
+                                for k, v in scorer_params.items()])
+    test_scorer_params = dict([(k, np.asarray(v)[test]
+                                if hasattr(v, '__len__')
+                                and len(v) == n_samples
+                                else v)
+                               for k, v in scorer_params.items()])
+
     if parameters is not None:
         estimator.set_params(**parameters)
 
@@ -1231,13 +1252,16 @@
     X_train, y_train = _safe_split(estimator, X, y, train)
     X_test, y_test = _safe_split(estimator, X, y, test, train)
+
     if y_train is None:
         estimator.fit(X_train, **fit_params)
     else:
         estimator.fit(X_train, y_train, **fit_params)
-    test_score = _score(estimator, X_test, y_test, scorer)
+    test_score = _score(estimator, X_test, y_test, scorer,
+                        **test_scorer_params)
     if return_train_score:
-        train_score = _score(estimator, X_train, y_train, scorer)
+        train_score = _score(estimator, X_train, y_train, scorer,
+                             **train_scorer_params)
 
     scoring_time = time.time() - start_time
 
@@ -1286,12 +1310,12 @@ def _safe_split(estimator, X, y, indices, train_indices=None):
     return X_subset, y_subset
 
 
-def _score(estimator, X_test, y_test, scorer):
+def _score(estimator, X_test, y_test, scorer, **params):
     """Compute the score of an estimator on a given test set."""
     if y_test is None:
-        score = scorer(estimator, X_test)
+        score = scorer(estimator, X_test, **params)
     else:
-        score = scorer(estimator, X_test, y_test)
+        score = scorer(estimator, X_test, y_test, **params)
     if not isinstance(score, numbers.Number):
         raise ValueError("scoring must return a number, got %s (%s) instead."
                          % (str(score), type(score)))
diff --git a/sklearn/feature_selection/rfe.py b/sklearn/feature_selection/rfe.py
index 86c56e1f3264a..abe3caf370a26 100644
--- a/sklearn/feature_selection/rfe.py
+++ b/sklearn/feature_selection/rfe.py
@@ -306,7 +306,7 @@ def __init__(self, estimator, step=1, cv=None, scoring=None,
         self.estimator_params = estimator_params
         self.verbose = verbose
 
-    def fit(self, X, y):
+    def fit(self, X, y, sample_weight=None):
         """Fit the RFE model and automatically tune the number of selected
            features.
 
@@ -319,6 +319,9 @@ def fit(self, X, y):
         y : array-like, shape = [n_samples]
             Target values (integers for classification, real numbers for
             regression).
+
+        sample_weight : array-like, shape = [n_samples], optional (default=None)
+            Sample weights.
""" X, y = check_X_y(X, y, "csr") # Initialization @@ -332,17 +335,27 @@ def fit(self, X, y): # Cross-validation for n, (train, test) in enumerate(cv): - X_train, y_train = _safe_split(self.estimator, X, y, train) - X_test, y_test = _safe_split(self.estimator, X, y, test, train) + X_train, y_train = _safe_split( + self.estimator, X, y, train) + X_test, y_test = _safe_split( + self.estimator, X, y, test, train) + + fit_params = dict() + score_params = dict() + if sample_weight is not None: + sample_weight = np.asarray(sample_weight) + fit_params['sample_weight'] = sample_weight[train] + score_params['sample_weight'] = sample_weight[test] # Compute a full ranking of the features - ranking_ = rfe.fit(X_train, y_train).ranking_ + ranking_ = rfe.fit(X_train, y_train, **fit_params).ranking_ # Score each subset of features for k in range(0, max(ranking_)): mask = np.where(ranking_ <= k + 1)[0] estimator = clone(self.estimator) - estimator.fit(X_train[:, mask], y_train) - score = _score(estimator, X_test[:, mask], y_test, scorer) + estimator.fit(X_train[:, mask], y_train, **fit_params) + score = _score( + estimator, X_test[:, mask], y_test, scorer, **score_params) if self.verbose > 0: print("Finished fold with %d / %d feature ranks, score=%f" @@ -358,7 +371,10 @@ def fit(self, X, y): n_features_to_select=k+1, step=self.step, estimator_params=self.estimator_params) - rfe.fit(X, y) + if sample_weight is not None: + rfe.fit(X, y, sample_weight=sample_weight) + else: + rfe.fit(X, y) # Set final attributes self.support_ = rfe.support_ diff --git a/sklearn/grid_search.py b/sklearn/grid_search.py index 280dbb32b1e54..4d7cb28816762 100644 --- a/sklearn/grid_search.py +++ b/sklearn/grid_search.py @@ -8,6 +8,7 @@ # Gael Varoquaux # Andreas Mueller # Olivier Grisel +# Noel Dawe # License: BSD 3 clause from abc import ABCMeta, abstractmethod @@ -226,7 +227,8 @@ def fit_grid_point(X, y, estimator, parameters, train, test, scorer, n_samples_test : int Number of test samples in this split. """ - score, n_samples_test, _ = _fit_and_score(estimator, X, y, scorer, train, + score, n_samples_test, _ = _fit_and_score(estimator, X, y, None, + scorer, train, test, verbose, parameters, fit_params) return score, parameters, n_samples_test @@ -279,7 +281,8 @@ class BaseSearchCV(six.with_metaclass(ABCMeta, BaseEstimator, @abstractmethod def __init__(self, estimator, scoring=None, fit_params=None, n_jobs=1, iid=True, - refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'): + refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', + scorer_params=None): self.scoring = scoring self.estimator = estimator @@ -290,8 +293,9 @@ def __init__(self, estimator, scoring=None, self.cv = cv self.verbose = verbose self.pre_dispatch = pre_dispatch + self.scorer_params = scorer_params - def score(self, X, y=None): + def score(self, X, y=None, **scorer_params): """Returns the score on the given test data and labels, if the search estimator has been refit. The ``score`` function of the best estimator is used, or the ``scoring`` parameter where unavailable. 
@@ -312,12 +316,12 @@ def score(self, X, y=None):
 
         """
         if hasattr(self.best_estimator_, 'score'):
-            return self.best_estimator_.score(X, y)
+            return self.best_estimator_.score(X, y, **scorer_params)
         if self.scorer_ is None:
             raise ValueError("No score function explicitly defined, "
                              "and the estimator doesn't provide one %s"
                              % self.best_estimator_)
-        return self.scorer_(self.best_estimator_, X, y)
+        return self.scorer_(self.best_estimator_, X, y, **scorer_params)
 
     @property
     def predict(self):
@@ -350,6 +354,7 @@ def _fit(self, X, y, parameter_iterable):
                 raise ValueError('Target variable (y) has a different number '
                                  'of samples (%i) than data (X: %i samples)'
                                  % (len(y), n_samples))
+
         cv = check_cv(cv, X, y, classifier=is_classifier(estimator))
 
         if self.verbose > 0:
@@ -367,9 +372,10 @@
             n_jobs=self.n_jobs, verbose=self.verbose,
             pre_dispatch=pre_dispatch
         )(
-            delayed(_fit_and_score)(clone(base_estimator), X, y, self.scorer_,
-                                    train, test, self.verbose, parameters,
-                                    self.fit_params, return_parameters=True)
+            delayed(_fit_and_score)(clone(base_estimator), X, y,
+                                    self.scorer_, train, test,
+                                    self.verbose, parameters, self.fit_params,
+                                    self.scorer_params, return_parameters=True)
                 for parameters in parameter_iterable
                 for train, test in cv)
 
@@ -411,14 +417,15 @@
         self.best_score_ = best.mean_validation_score
 
         if self.refit:
+            fit_params = self.fit_params
             # fit the best estimator using the entire dataset
             # clone first to work around broken estimators
             best_estimator = clone(base_estimator).set_params(
                 **best.parameters)
             if y is not None:
-                best_estimator.fit(X, y, **self.fit_params)
+                best_estimator.fit(X, y, **fit_params)
             else:
-                best_estimator.fit(X, **self.fit_params)
+                best_estimator.fit(X, **fit_params)
             self.best_estimator_ = best_estimator
         return self
 
@@ -566,10 +573,11 @@ class GridSearchCV(BaseSearchCV):
 
     def __init__(self, estimator, param_grid, scoring=None,
                  fit_params=None, n_jobs=1, iid=True,
-                 refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs'):
+                 refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
+                 scorer_params=None):
         super(GridSearchCV, self).__init__(
             estimator, scoring, fit_params, n_jobs, iid,
-            refit, cv, verbose, pre_dispatch)
+            refit, cv, verbose, pre_dispatch, scorer_params)
         self.param_grid = param_grid
         _check_param_grid(param_grid)
 
@@ -586,7 +594,6 @@ def fit(self, X, y=None):
         y : array-like, shape = [n_samples] or [n_samples, n_output], optional
             Target relative to X for classification or regression;
             None for unsupervised learning.
-
         """
         return self._fit(X, y, ParameterGrid(self.param_grid))
 
@@ -714,7 +721,8 @@ class RandomizedSearchCV(BaseSearchCV):
 
     def __init__(self, estimator, param_distributions, n_iter=10, scoring=None,
                  fit_params=None, n_jobs=1, iid=True, refit=True, cv=None,
-                 verbose=0, pre_dispatch='2*n_jobs', random_state=None):
+                 verbose=0, pre_dispatch='2*n_jobs', random_state=None,
+                 scorer_params=None):
 
         self.param_distributions = param_distributions
         self.n_iter = n_iter
@@ -722,7 +730,7 @@
         super(RandomizedSearchCV, self).__init__(
             estimator=estimator, scoring=scoring, fit_params=fit_params,
             n_jobs=n_jobs, iid=iid, refit=refit, cv=cv, verbose=verbose,
-            pre_dispatch=pre_dispatch)
+            pre_dispatch=pre_dispatch, scorer_params=scorer_params)
 
     def fit(self, X, y=None):
         """Run fit on the estimator with randomly drawn parameters.
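
A usage sketch for the new ``scorer_params`` constructor argument on the
search estimators, assuming this branch is installed; the scorer, data and
grid are illustrative. Both dicts are sliced per CV fold by
``_fit_and_score``, so the full-length weight vector is passed once:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC
    from sklearn.grid_search import GridSearchCV

    def weighted_accuracy(estimator, X, y, sample_weight=None):
        return np.average(estimator.predict(X) == y, weights=sample_weight)

    X, y = make_classification(n_samples=150, random_state=0)
    weights = np.random.RandomState(2).uniform(0.5, 2.0, size=len(y))
    params = {'sample_weight': weights}

    search = GridSearchCV(SVC(kernel="linear"), {'C': [0.1, 1, 10]},
                          scoring=weighted_accuracy,
                          fit_params=params, scorer_params=params)
    # On refit, the unsliced fit_params are forwarded to the final fit on
    # the whole dataset.
    search.fit(X, y)
    print(search.best_params_, search.best_score_)
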
@@ -736,7 +744,6 @@ def fit(self, X, y=None):
         y : array-like, shape = [n_samples] or [n_samples, n_output], optional
             Target relative to X for classification or regression;
             None for unsupervised learning.
-
         """
         sampled_params = ParameterSampler(self.param_distributions,
                                           self.n_iter,
diff --git a/sklearn/learning_curve.py b/sklearn/learning_curve.py
index 55c4cf6547d86..3b9bff5613561 100644
--- a/sklearn/learning_curve.py
+++ b/sklearn/learning_curve.py
@@ -17,7 +17,8 @@
 from .utils.fixes import astype
 
 
-def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 5),
+def learning_curve(estimator, X, y, sample_weight=None,
+                   train_sizes=np.linspace(0.1, 1.0, 10),
                    cv=None, scoring=None, exploit_incremental_learning=False,
                    n_jobs=1, pre_dispatch="all", verbose=0):
     """Learning curve.
@@ -44,6 +45,9 @@ def learning_curve(estimator, X, y, train_sizes=np.linspace(0.1, 1.0, 5),
         Target relative to X for classification or regression;
         None for unsupervised learning.
 
+    sample_weight : array-like, shape (n_samples), optional
+        Sample weights.
+
    train_sizes : array-like, shape (n_ticks,), dtype float or int
        Relative or absolute numbers of training examples that will be used to
        generate the learning curve. If the dtype is float, it is regarded as a
@@ -128,12 +132,19 @@
     if exploit_incremental_learning:
         classes = np.unique(y) if is_classifier(estimator) else None
         out = parallel(delayed(_incremental_fit_estimator)(
-            clone(estimator), X, y, classes, train, test, train_sizes_abs,
+            clone(estimator), X, y, sample_weight,
+            classes, train, test, train_sizes_abs,
             scorer, verbose) for train, test in cv)
     else:
+        if sample_weight is not None:
+            params = dict(sample_weight=sample_weight)
+        else:
+            params = None
         out = parallel(delayed(_fit_and_score)(
-            clone(estimator), X, y, scorer, train[:n_train_samples], test,
-            verbose, parameters=None, fit_params=None, return_train_score=True)
+            clone(estimator), X, y,
+            scorer, train[:n_train_samples], test,
+            verbose, parameters=None, fit_params=params, scorer_params=params,
+            return_train_score=True)
             for train, test in cv for n_train_samples in train_sizes_abs)
         out = np.array(out)[:, :2]
         n_cv_folds = out.shape[0] // n_unique_ticks
@@ -203,29 +214,51 @@ def _translate_train_sizes(train_sizes, n_max_training_samples):
     return train_sizes_abs
 
 
-def _incremental_fit_estimator(estimator, X, y, classes, train, test,
+def _incremental_fit_estimator(estimator, X, y, sample_weight,
+                               classes, train, test,
                                train_sizes, scorer, verbose):
     """Train estimator on training subsets incrementally and compute scores."""
     train_scores, test_scores = [], []
     partitions = zip(train_sizes, np.split(train, train_sizes)[:-1])
     for n_train_samples, partial_train in partitions:
         train_subset = train[:n_train_samples]
-        X_train, y_train = _safe_split(estimator, X, y, train_subset)
-        X_partial_train, y_partial_train = _safe_split(estimator, X, y,
-                                                       partial_train)
-        X_test, y_test = _safe_split(estimator, X, y, test, train_subset)
+        X_train, y_train = _safe_split(
+            estimator, X, y, train_subset)
+        X_partial_train, y_partial_train = \
+            _safe_split(estimator, X, y, partial_train)
+        X_test, y_test = _safe_split(
+            estimator, X, y, test, train_subset)
+
+        # TODO: replace sample_weight with fit_params and scorer_params
+
+        fit_params = dict()
+        train_scorer_params = dict()
+        test_scorer_params = dict()
+        if sample_weight is not None:
+            sample_weight = np.asarray(sample_weight)
+            sample_weight_train = sample_weight[train_subset]
+            sample_weight_partial_train = sample_weight[partial_train]
+            sample_weight_test = sample_weight[test]
+            fit_params['sample_weight'] = sample_weight_partial_train
+            train_scorer_params['sample_weight'] = sample_weight_train
+            test_scorer_params['sample_weight'] = sample_weight_test
+
         if y_partial_train is None:
-            estimator.partial_fit(X_partial_train, classes=classes)
+            estimator.partial_fit(X_partial_train,
+                                  classes=classes, **fit_params)
         else:
             estimator.partial_fit(X_partial_train, y_partial_train,
-                                  classes=classes)
-        train_scores.append(_score(estimator, X_train, y_train, scorer))
-        test_scores.append(_score(estimator, X_test, y_test, scorer))
+                                  classes=classes, **fit_params)
+        train_scores.append(_score(
+            estimator, X_train, y_train, scorer, **train_scorer_params))
+        test_scores.append(_score(
+            estimator, X_test, y_test, scorer, **test_scorer_params))
 
     return np.array((train_scores, test_scores)).T
 
 
-def validation_curve(estimator, X, y, param_name, param_range, cv=None,
-                     scoring=None, n_jobs=1, pre_dispatch="all", verbose=0):
+def validation_curve(estimator, X, y, param_name, param_range,
+                     sample_weight=None, cv=None, scoring=None,
+                     n_jobs=1, pre_dispatch="all", verbose=0):
     """Validation curve.
 
     Determine training and test scores for varying parameter values.
@@ -254,6 +287,9 @@
     param_range : array-like, shape (n_values,)
         The values of the parameter that will be evaluated.
 
+    sample_weight : array-like, shape (n_samples,), optional
+        Sample weights.
+
     cv : integer, cross-validation generator, optional
         If an integer is passed, it is the number of folds (defaults to 3).
         Specific cross-validation objects can be passed, see
@@ -295,9 +331,14 @@
     parallel = Parallel(n_jobs=n_jobs, pre_dispatch=pre_dispatch,
                         verbose=verbose)
+    if sample_weight is not None:
+        params = dict(sample_weight=sample_weight)
+    else:
+        params = None
     out = parallel(delayed(_fit_and_score)(
         estimator, X, y, scorer, train, test, verbose,
-        parameters={param_name: v}, fit_params=None, return_train_score=True)
+        parameters={param_name: v}, fit_params=params, scorer_params=params,
+        return_train_score=True)
         for train, test in cv for v in param_range)
     out = np.asarray(out)[:, :2]
diff --git a/sklearn/tests/test_cross_validation.py b/sklearn/tests/test_cross_validation.py
index e3af30a1b2bae..f0ccca66da02c 100644
--- a/sklearn/tests/test_cross_validation.py
+++ b/sklearn/tests/test_cross_validation.py
@@ -54,9 +54,9 @@ def fit(self, X, Y=None, sample_weight=None, class_prior=None):
         if X.ndim >= 3 and not self.allow_nd:
             raise ValueError('X cannot be d')
         if sample_weight is not None:
-            assert_true(sample_weight.shape[0] == X.shape[0],
+            assert_true(len(sample_weight) == X.shape[0],
                         'MockClassifier extra fit_param sample_weight.shape[0]'
-                        ' is {0}, should be {1}'.format(sample_weight.shape[0],
+                        ' is {0}, should be {1}'.format(len(sample_weight),
                                                         X.shape[0]))
         if class_prior is not None:
             assert_true(class_prior.shape[0] == len(np.unique(y)),
@@ -70,13 +70,15 @@ def predict(self, T):
             T = T.reshape(len(T), -1)
         return T.shape[0]
 
-    def score(self, X=None, Y=None):
+    def score(self, X=None, Y=None, sample_weight=None):
         return 1. / (1 + np.abs(self.a))
 
 
 X = np.ones((10, 2))
 X_sparse = coo_matrix(X)
 y = np.arange(10) // 2
+rng = np.random.RandomState(0)
+int_weights = rng.randint(10, size=y.shape)
 
 
 ##############################################################################
 # Tests
 
@@ -466,8 +468,10 @@ def test_cross_val_score():
     for a in range(-10, 10):
         clf.a = a
         # Smoke test
-        scores = cval.cross_val_score(clf, X, y)
-        assert_array_equal(scores, clf.score(X, y))
+        params = dict(sample_weight=int_weights)
+        scores = cval.cross_val_score(clf, X, y,
+                                      fit_params=params, scorer_params=params)
+        assert_array_equal(scores, clf.score(X, y, sample_weight=int_weights))
 
         # test with multioutput y
         scores = cval.cross_val_score(clf, X_sparse, X)
@@ -480,6 +484,11 @@
     scores = cval.cross_val_score(clf, X_sparse, X)
     assert_array_equal(scores, clf.score(X_sparse, X))
 
+    # test with sample_weight as list
+    params = dict(sample_weight=int_weights.tolist())
+    scores = cval.cross_val_score(
+        clf, X, y, fit_params=params, scorer_params=params)
+
     # test with X and y as list
     list_check = lambda x: isinstance(x, list)
     clf = CheckingClassifier(check_X=list_check)
diff --git a/sklearn/tests/test_grid_search.py b/sklearn/tests/test_grid_search.py
index 02183d18cd2fc..81245a6ccf22f 100644
--- a/sklearn/tests/test_grid_search.py
+++ b/sklearn/tests/test_grid_search.py
@@ -51,8 +51,13 @@ class MockClassifier(object):
     def __init__(self, foo_param=0):
         self.foo_param = foo_param
 
-    def fit(self, X, Y):
+    def fit(self, X, Y, sample_weight=None):
         assert_true(len(X) == len(Y))
+        if sample_weight is not None:
+            assert_true(len(sample_weight) == len(X),
+                        'MockClassifier sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(len(sample_weight),
+                                                        len(X)))
         return self
 
     def predict(self, T):
@@ -62,7 +67,12 @@ def predict(self, T):
     decision_function = predict
     transform = predict
 
-    def score(self, X=None, Y=None):
+    def score(self, X=None, Y=None, sample_weight=None):
+        if X is not None and sample_weight is not None:
+            assert_true(len(sample_weight) == len(X),
+                        'MockClassifier sample_weight.shape[0]'
+                        ' is {0}, should be {1}'.format(len(sample_weight),
+                                                        len(X)))
         if self.foo_param > 1:
             score = 1.
         else:
@@ -85,6 +95,7 @@ def score(self):
 
 X = np.array([[-1, -1], [-2, -1], [1, 1], [2, 1]])
 y = np.array([1, 1, 2, 2])
+sample_weight = np.array([1, 2, 3, 4])
 
 
 def test_parameter_grid():
@@ -638,3 +649,19 @@ def test_grid_search_allows_nans():
         ('classifier', MockClassifier()),
     ])
     GridSearchCV(p, {'classifier__foo_param': [1, 2, 3]}, cv=2).fit(X, y)
+
+
+def test_grid_search_with_sample_weights():
+    """Test grid searching with sample weights"""
+    est_parameters = {"foo_param": [1, 2, 3]}
+    cv = KFold(y.shape[0], n_folds=2, random_state=0)
+    for search_cls in (GridSearchCV, RandomizedSearchCV):
+        params = dict(sample_weight=sample_weight)
+        grid_search = search_cls(MockClassifier(), est_parameters, cv=cv,
+                                 fit_params=params, scorer_params=params)
+        grid_search.fit(X, y)
+    # check that sample_weight can be a list
+    params = dict(sample_weight=sample_weight.tolist())
+    grid_search = GridSearchCV(MockClassifier(), est_parameters, cv=cv,
+                               fit_params=params, scorer_params=params)
+    grid_search.fit(X, y)
diff --git a/sklearn/tests/test_learning_curve.py b/sklearn/tests/test_learning_curve.py
index 62a05dd19799e..c039567669274 100644
--- a/sklearn/tests/test_learning_curve.py
+++ b/sklearn/tests/test_learning_curve.py
@@ -25,7 +25,7 @@ def __init__(self, n_max_train_sizes):
         self.train_sizes = 0
         self.X_subset = None
 
-    def fit(self, X_subset, y_subset=None):
+    def fit(self, X_subset, y_subset=None, **params):
         self.X_subset = X_subset
         self.train_sizes = X_subset.shape[0]
         return self
@@ -65,7 +65,7 @@ def __init__(self, param=0.5):
         self.X_subset = None
         self.param = param
 
-    def fit(self, X_subset, y_subset):
+    def fit(self, X_subset, y_subset, **params):
         self.X_subset = X_subset
         self.train_sizes = X_subset.shape[0]
         return self
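
The learning_curve and validation_curve changes in this patch can be
exercised as below, assuming this branch is installed; the data are
synthetic and, with the default ``scoring=None``, the sketch also assumes
the classifier's ``score`` method accepts ``sample_weight`` (otherwise pass
a weighted callable via ``scoring`` as in the earlier sketches):

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC
    from sklearn.learning_curve import learning_curve, validation_curve

    X, y = make_classification(n_samples=200, random_state=0)
    weights = np.random.RandomState(3).uniform(0.5, 2.0, size=len(y))

    # The full-length weight vector is forwarded as both fit_params and
    # scorer_params and sliced per fold by _fit_and_score.
    train_sizes, train_scores, test_scores = learning_curve(
        SVC(kernel="linear"), X, y, sample_weight=weights, cv=3)

    train_scores, test_scores = validation_curve(
        SVC(kernel="linear"), X, y, param_name="C",
        param_range=[0.1, 1.0, 10.0], sample_weight=weights, cv=3)
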