diff --git a/sklearn/decomposition/fastica_.py b/sklearn/decomposition/fastica_.py index 01f47a2789992..d4c1c61e3a0b7 100644 --- a/sklearn/decomposition/fastica_.py +++ b/sklearn/decomposition/fastica_.py @@ -394,6 +394,7 @@ def __init__(self, n_components=None, algorithm='parallel', whiten=True, self.tol = tol self.w_init = w_init self.random_state = random_state + self.compute_sources = compute_sources def fit_transform(self, X, y=None): """Fit the model and recover the sources from X. @@ -413,7 +414,8 @@ def fit_transform(self, X, y=None): X=X, n_components=self.n_components, algorithm=self.algorithm, whiten=self.whiten, fun=self.fun, fun_args=fun_args, max_iter=self.max_iter, tol=self.tol, w_init=self.w_init, - random_state=self.random_state, return_X_mean=True) + random_state=self.random_state, return_X_mean=True, + compute_sources=self.compute_sources) if self.whiten: self.components_ = np.dot(unmixing_, whitening_) self.mean_ = X_mean @@ -422,7 +424,10 @@ def fit_transform(self, X, y=None): self.components_ = unmixing_ self.mixing_ = linalg.pinv(self.components_) - self.sources_ = sources_ + + if self.compute_sources: + self.sources_ = sources_ + return sources_ def fit(self, X, y=None): diff --git a/sklearn/linear_model/__init__.py b/sklearn/linear_model/__init__.py index 9c609c3918793..fa099b202bd7d 100644 --- a/sklearn/linear_model/__init__.py +++ b/sklearn/linear_model/__init__.py @@ -22,7 +22,8 @@ from .ridge import (Ridge, RidgeCV, RidgeClassifier, RidgeClassifierCV, ridge_regression) from .logistic import LogisticRegression -from .omp import orthogonal_mp, orthogonal_mp_gram, OrthogonalMatchingPursuit +from .omp import (orthogonal_mp, orthogonal_mp_gram, OrthogonalMatchingPursuit, + OrthogonalMatchingPursuitCV) from .passive_aggressive import PassiveAggressiveClassifier from .passive_aggressive import PassiveAggressiveRegressor from .perceptron import Perceptron @@ -49,6 +50,7 @@ 'MultiTaskElasticNet', 'MultiTaskLasso', 'OrthogonalMatchingPursuit', + 'OrthogonalMatchingPursuitCV', 'PassiveAggressiveClassifier', 'PassiveAggressiveRegressor', 'Perceptron', diff --git a/sklearn/linear_model/omp.py b/sklearn/linear_model/omp.py index ffb30eb94f279..e344ab1a8bd8c 100644 --- a/sklearn/linear_model/omp.py +++ b/sklearn/linear_model/omp.py @@ -13,7 +13,9 @@ from .base import LinearModel from ..base import RegressorMixin -from ..utils import array2d +from ..utils import array2d, as_float_array +from ..cross_validation import check_cv +from ..externals.joblib import Parallel, delayed from ..utils.arrayfuncs import solve_triangular premature = """ Orthogonal matching pursuit ended prematurely due to linear @@ -21,7 +23,8 @@ """ -def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True): +def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True, + return_path=False): """Orthogonal Matching Pursuit step using the Cholesky decomposition. Parameters @@ -45,6 +48,10 @@ def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True): Returns ------- + return_path: bool, optional. Default: False + Whether to return every value of the nonzero coefficients along the + forward path. Useful for cross-validation. + gamma: array, shape = (n_nonzero_coefs,) Non-zero elements of the solution @@ -52,6 +59,11 @@ def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True): Indices of the positions of the elements in gamma within the solution vector + coefs, array, shape = (n_features, n_nonzero_coefs) + The first k values of column k correspond to the coefficient value + for the active features at that step. The lower left triangle contains + garbage. Only returned if ``return_path=True``. + """ if copy_X: X = X.copy('F') @@ -71,6 +83,8 @@ def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True): max_features = X.shape[1] if tol is not None else n_nonzero_coefs L = np.empty((max_features, max_features), dtype=X.dtype) L[0, 0] = 1. + if return_path: + coefs = np.empty_like(L) while True: lam = np.argmax(np.abs(np.dot(X.T, residual))) @@ -94,18 +108,22 @@ def _cholesky_omp(X, y, n_nonzero_coefs, tol=None, copy_X=True): # solves LL'x = y as a composition of two triangular systems gamma, _ = potrs(L[:n_active, :n_active], alpha[:n_active], lower=True, overwrite_b=False) - + if return_path: + coefs[:n_active, n_active - 1] = gamma residual = y - np.dot(X[:, :n_active], gamma) if tol is not None and nrm2(residual) ** 2 <= tol: break elif n_active == max_features: break - return gamma, indices[:n_active] + if return_path: + return gamma, indices[:n_active], coefs[:, :n_active] + else: + return gamma, indices[:n_active] def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, - copy_Gram=True, copy_Xy=True): + copy_Gram=True, copy_Xy=True, return_path=False): """Orthogonal Matching Pursuit step on a precomputed Gram matrix. This function uses the the Cholesky decomposition method. @@ -138,6 +156,10 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, Returns ------- + return_path: bool, optional. Default: False + Whether to return every value of the nonzero coefficients along the + forward path. Useful for cross-validation. + gamma: array, shape = (n_nonzero_coefs,) Non-zero elements of the solution @@ -145,6 +167,11 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, Indices of the positions of the elements in gamma within the solution vector + coefs, array, shape = (n_features, n_nonzero_coefs) + The first k values of column k correspond to the coefficient value + for the active features at that step. The lower left triangle contains + garbage. Only returned if ``return_path=True``. + """ Gram = Gram.copy('F') if copy_Gram else np.asfortranarray(Gram) @@ -165,6 +192,8 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, max_features = len(Gram) if tol is not None else n_nonzero_coefs L = np.empty((max_features, max_features), dtype=Gram.dtype) L[0, 0] = 1. + if return_path: + coefs = np.empty_like(L) while True: lam = np.argmax(np.abs(alpha)) @@ -188,7 +217,8 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, # solves LL'x = y as a composition of two triangular systems gamma, _ = potrs(L[:n_active, :n_active], Xy[:n_active], lower=True, overwrite_b=False) - + if return_path: + coefs[:n_active, n_active - 1] = gamma beta = np.dot(Gram[:, :n_active], gamma) alpha = Xy - beta if tol is not None: @@ -200,11 +230,14 @@ def _gram_omp(Gram, Xy, n_nonzero_coefs, tol_0=None, tol=None, elif n_active == max_features: break - return gamma, indices[:n_active] + if return_path: + return gamma, indices[:n_active], coefs[:, :n_active] + else: + return gamma, indices[:n_active] def orthogonal_mp(X, y, n_nonzero_coefs=None, tol=None, precompute_gram=False, - copy_X=True): + copy_X=True, return_path=False): """Orthogonal Matching Pursuit (OMP) Solves n_targets Orthogonal Matching Pursuit problems. @@ -241,10 +274,18 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, tol=None, precompute_gram=False, value is only helpful if X is already Fortran-ordered, otherwise a copy is made anyway. + return_path: bool, optional. Default: False + Whether to return every value of the nonzero coefficients along the + forward path. Useful for cross-validation. + Returns ------- coef: array, shape = (n_features,) or (n_features, n_targets) - Coefficients of the OMP solution + Coefficients of the OMP solution. If `return_path=True`, this contains + the whole coefficient path. In this case its shape is + (n_features, n_features) or (n_features, n_targets, n_features) and + iterating over the last axis yields coefficients in increasing order + of active features. See also -------- @@ -297,17 +338,28 @@ def orthogonal_mp(X, y, n_nonzero_coefs=None, tol=None, precompute_gram=False, return orthogonal_mp_gram(G, Xy, n_nonzero_coefs, tol, norms_squared, copy_Gram=copy_X, copy_Xy=False) - coef = np.zeros((X.shape[1], y.shape[1])) - for k in range(y.shape[1]): - x, idx = _cholesky_omp(X, y[:, k], n_nonzero_coefs, tol, - copy_X=copy_X) - coef[idx, k] = x + if return_path: + coef = np.zeros((X.shape[1], y.shape[1], X.shape[1])) + else: + coef = np.zeros((X.shape[1], y.shape[1])) + + for k in xrange(y.shape[1]): + out = _cholesky_omp(X, y[:, k], n_nonzero_coefs, tol, + copy_X=copy_X, return_path=return_path) + if return_path: + _, idx, coefs = out + coef = coef[:, :, :len(idx)] + for n_active, x in enumerate(coefs.T): + coef[idx[:n_active + 1], k, n_active] = x[:n_active + 1] + else: + x, idx = out + coef[idx, k] = x return np.squeeze(coef) def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, tol=None, norms_squared=None, copy_Gram=True, - copy_Xy=True): + copy_Xy=True, return_path=False): """Gram Orthogonal Matching Pursuit (OMP) Solves n_targets Orthogonal Matching Pursuit problems using only @@ -340,10 +392,18 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, tol=None, Whether the covariance vector Xy must be copied by the algorithm. If False, it may be overwritten. + return_path: bool, optional. Default: False + Whether to return every value of the nonzero coefficients along the + forward path. Useful for cross-validation. + Returns ------- coef: array, shape = (n_features,) or (n_features, n_targets) - Coefficients of the OMP solution + Coefficients of the OMP solution. If `return_path=True`, this contains + the whole coefficient path. In this case its shape is + (n_features, n_features) or (n_features, n_targets, n_features) and + iterating over the last axis yields coefficients in increasing order + of active features. See also -------- @@ -387,12 +447,26 @@ def orthogonal_mp_gram(Gram, Xy, n_nonzero_coefs=None, tol=None, if tol is None and n_nonzero_coefs > len(Gram): raise ValueError("The number of atoms cannot be more than the number " "of features") - coef = np.zeros((len(Gram), Xy.shape[1])) + + if return_path: + coef = np.zeros((len(Gram), Xy.shape[1], len(Gram))) + else: + coef = np.zeros((len(Gram), Xy.shape[1])) + for k in range(Xy.shape[1]): - x, idx = _gram_omp(Gram, Xy[:, k], n_nonzero_coefs, - norms_squared[k] if tol is not None else None, tol, - copy_Gram=copy_Gram, copy_Xy=copy_Xy) - coef[idx, k] = x + out = _gram_omp(Gram, Xy[:, k], n_nonzero_coefs, + norms_squared[k] if tol is not None else None, tol, + copy_Gram=copy_Gram, copy_Xy=copy_Xy, + return_path=return_path) + if return_path: + _, idx, coefs = out + coef = coef[:, :, :len(idx)] + for n_active, x in enumerate(coefs.T): + coef[idx[:n_active + 1], k, n_active] = x[:n_active + 1] + else: + x, idx = out + coef[idx, k] = x + return np.squeeze(coef) @@ -561,3 +635,184 @@ def fit(self, X, y, Gram=None, Xy=None): self._set_intercept(X_mean, y_mean, X_std) return self + + +def _omp_path_residues(X_train, y_train, X_test, y_test, copy=True, + fit_intercept=True, normalize=True, max_iter=100): + """Compute the residues on left-out data for a full LARS path + + Parameters + ----------- + X_train : array, shape (n_samples, n_features) + The data to fit the LARS on + + y_train : array, shape (n_samples) + The target variable to fit LARS on + + X_test : array, shape (n_samples, n_features) + The data to compute the residues on + + y_test : array, shape (n_samples) + The target variable to compute the residues on + + copy : boolean, optional + Whether X_train, X_test, y_train and y_test should be copied. If + False, they may be overwritten. + + fit_intercept : boolean + whether to calculate the intercept for this model. If set + to false, no intercept will be used in calculations + (e.g. data is expected to be already centered). + + normalize : boolean, optional, default False + If True, the regressors X will be normalized before regression. + + max_iter : integer, optional + Maximum numbers of iterations to perform, therefore maximum features + to include. 100 by default. + + Returns + ------- + residues: array, shape [n_samples, max_features] + Residues of the prediction on the test data + """ + + if copy: + X_train = X_train.copy() + y_train = y_train.copy() + X_test = X_test.copy() + y_test = y_test.copy() + + if fit_intercept: + X_mean = X_train.mean(axis=0) + X_train -= X_mean + X_test -= X_mean + y_mean = y_train.mean(axis=0) + y_train = as_float_array(y_train, copy=False) + y_train -= y_mean + y_test = as_float_array(y_test, copy=False) + y_test -= y_mean + + if normalize: + norms = np.sqrt(np.sum(X_train ** 2, axis=0)) + nonzeros = np.flatnonzero(norms) + X_train[:, nonzeros] /= norms[nonzeros] + + coefs = orthogonal_mp(X_train, y_train, n_nonzero_coefs=max_iter, tol=None, + precompute_gram=False, copy_X=False, + return_path=True) + if coefs.ndim == 1: + coefs = coefs[:, np.newaxis] + if normalize: + coefs[nonzeros] /= norms[nonzeros][:, np.newaxis] + + return np.dot(coefs.T, X_test.T) - y_test + + +class OrthogonalMatchingPursuitCV(LinearModel, RegressorMixin): + """Cross-validated Orthogonal Mathching Pursuit model (OMP) + + Parameters + ---------- + copy : bool, optional + Whether the design matrix X must be copied by the algorithm. A false + value is only helpful if X is already Fortran-ordered, otherwise a + copy is made anyway. + + fit_intercept : boolean, optional + whether to calculate the intercept for this model. If set + to false, no intercept will be used in calculations + (e.g. data is expected to be already centered). + + normalize : boolean, optional + If False, the regressors X are assumed to be already normalized. + + max_iter : integer, optional + Maximum numbers of iterations to perform, therefore maximum features + to include. 10% of ``n_features`` but at least 1 by default. + + cv : cross-validation generator, optional + see :mod:`sklearn.cross_validation`. If ``None`` is passed, default to + a 5-fold strategy + + n_jobs : integer, optional + Number of CPUs to use during the cross validation. If ``-1``, use + all the CPUs + + verbose : boolean or integer, optional + Sets the verbosity amount + + Attributes + ---------- + `n_nonzero_coefs_` : int + Estimated number of non-zero coefficients giving the best mean + squared error over the cross-validation folds. + + `coef_` : array, shape = (n_features,) or (n_features, n_targets) + parameter vector (w in the problem formulation). + + `intercept_` : float or array, shape = (n_targets,) + independent term in decision function. + + See also + -------- + orthogonal_mp + orthogonal_mp_gram + lars_path + Lars + LassoLars + OrthogonalMatchingPursuit + LarsCV + LassoLarsCV + decomposition.sparse_encode + + """ + def __init__(self, copy=True, fit_intercept=True, normalize=True, + max_iter=None, cv=None, n_jobs=1, verbose=False): + self.copy = copy + self.fit_intercept = fit_intercept + self.normalize = normalize + self.max_iter = max_iter + self.cv = cv + self.n_jobs = n_jobs + self.verbose = verbose + + def fit(self, X, y): + """Fit the model using X, y as training data. + + Parameters + ---------- + X : array-like, shape = [n_samples, n_features] + Training data. + + y : array-like, shape = [n_samples] + Target values. + + Returns + ------- + self : object + returns an instance of self. + """ + X = array2d(X) + cv = check_cv(self.cv, X, y, classifier=False) + max_iter = (max(int(0.1 * X.shape[1]), 1) if not self.max_iter + else self.max_iter) + cv_paths = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)( + delayed(_omp_path_residues)( + X[train], y[train], X[test], y[test], self.copy, + self.fit_intercept, self.normalize, max_iter) + for train, test in cv) + + min_early_stop = min(fold.shape[0] for fold in cv_paths) + mse_folds = np.array([(fold[:min_early_stop] ** 2).mean(axis=1) + for fold in cv_paths]) + best_n_nonzero_coefs = np.argmin(mse_folds.mean(axis=0)) + 1 + self.n_nonzero_coefs_ = best_n_nonzero_coefs + omp = OrthogonalMatchingPursuit(n_nonzero_coefs=best_n_nonzero_coefs, + copy_X=self.copy, + fit_intercept=self.fit_intercept, + normalize=self.normalize) + omp.fit(X, y) + self.coef_ = omp.coef_ + self.intercept_ = omp.intercept_ + return self diff --git a/sklearn/linear_model/tests/test_omp.py b/sklearn/linear_model/tests/test_omp.py index 8425fd44d8f26..565553b2d726c 100644 --- a/sklearn/linear_model/tests/test_omp.py +++ b/sklearn/linear_model/tests/test_omp.py @@ -13,7 +13,8 @@ from sklearn.utils.testing import assert_greater from sklearn.linear_model import (orthogonal_mp, orthogonal_mp_gram, - OrthogonalMatchingPursuit) + OrthogonalMatchingPursuit, + OrthogonalMatchingPursuitCV) from sklearn.utils.fixes import count_nonzero from sklearn.datasets import make_sparse_coded_signal @@ -186,3 +187,28 @@ def test_no_atoms(): gamma_empty_gram = orthogonal_mp_gram(G, Xy_empty, 1) assert_equal(np.all(gamma_empty == 0), True) assert_equal(np.all(gamma_empty_gram == 0), True) + + +def test_omp_path(): + path = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=True) + last = orthogonal_mp(X, y, n_nonzero_coefs=5, return_path=False) + assert_equal(path.shape, (n_features, n_targets, 5)) + assert_array_almost_equal(path[:, :, -1], last) + path = orthogonal_mp_gram(G, Xy, n_nonzero_coefs=5, return_path=True) + last = orthogonal_mp_gram(G, Xy, n_nonzero_coefs=5, return_path=False) + assert_equal(path.shape, (n_features, n_targets, 5)) + assert_array_almost_equal(path[:, :, -1], last) + + +def test_omp_cv(): + y_ = y[:, 0] + gamma_ = gamma[:, 0] + ompcv = OrthogonalMatchingPursuitCV(normalize=True, fit_intercept=False, + max_iter=10, cv=5) + ompcv.fit(X, y_) + assert_equal(ompcv.n_nonzero_coefs_, n_nonzero_coefs) + assert_array_almost_equal(ompcv.coef_, gamma_) + omp = OrthogonalMatchingPursuit(normalize=True, fit_intercept=False, + n_nonzero_coefs=ompcv.n_nonzero_coefs_) + omp.fit(X, y_) + assert_array_almost_equal(ompcv.coef_, omp.coef_)