FEAT - Basic GeneralizedLinearEstimatorCV #311


Draft: wants to merge 5 commits into main

Conversation

@floriankozikowski (Contributor) commented May 22, 2025

Context of the PR

Implements a cross-validation wrapper for GeneralizedLinearEstimator along with a standalone example script. Provides hyperparameter grids over alpha and an optional l1_ratio, custom k-fold splitting, warm starts, scoring, and a final full-data refit (closes issue #308).

Contributions of the PR

  • class inherits from GeneralizedLinearEstimator, CV inspired by Celer
  • full pass over l1_ratio (if available) and alpha grids
  • custom CV via _kfold_split, yielding train/test splits without relying on sklearn’s built-in routines
  • warm_start: carry forward each fold’s solution (w) as the initial w_start for the next alpha
  • score & select: compute MSE (or user-provided scorer) per fold, track the mean loss in mse_path, and pick the (alpha_, l1_ratio_) with lowest average loss
  • final refit: once the best hyperparameters are found, update self.penalty in place and call super().fit(X, y) to train the returned model on the entire dataset
  • plot_generalized_linear_estimator.cv provides a simple example file

Considerations:

  • remove alpha_max
  • implement scikit-learn CV splits instead of our own functions
  • remove other PRs once final
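The CV loop the PR describes (custom k-fold splits, per-fold fits over an alpha grid, a mean-loss path, and a final full-data refit) can be sketched end-to-end. This is a minimal, self-contained stand-in that uses closed-form ridge instead of a skglm solver; the helper names `_kfold_split`, `mse_path`, and `alpha_` mirror the PR text, while `ridge_cv` is hypothetical:

```python
import numpy as np

def _kfold_split(n_samples, n_splits=4, rng=None):
    # yield (train, test) index pairs without sklearn, mirroring the
    # custom _kfold_split described in the PR
    idx = np.arange(n_samples)
    if rng is not None:
        rng.shuffle(idx)
    folds = np.array_split(idx, n_splits)
    for k in range(n_splits):
        test = folds[k]
        train = np.concatenate([folds[j] for j in range(n_splits) if j != k])
        yield train, test

def ridge_cv(X, y, alphas, n_splits=4):
    # stand-in estimator: closed-form ridge so the loop is runnable;
    # the real PR solves each fold with a skglm solver instead
    n, p = X.shape
    mse_path = np.zeros((len(alphas), n_splits))
    for i, alpha in enumerate(alphas):
        for k, (train, test) in enumerate(_kfold_split(n, n_splits)):
            A = X[train].T @ X[train] + alpha * np.eye(p)
            w = np.linalg.solve(A, X[train].T @ y[train])
            mse_path[i, k] = np.mean((X[test] @ w - y[test]) ** 2)
    # pick the alpha with the lowest mean loss across folds
    best = int(np.argmin(mse_path.mean(axis=1)))
    alpha_ = alphas[best]
    # final refit on the entire dataset with the selected alpha
    w_ = np.linalg.solve(X.T @ X + alpha_ * np.eye(p), X.T @ y)
    return alpha_, w_, mse_path
```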

Checks before merging PR

  • added documentation for any new feature
  • added unit tests
  • edited the what's new (if applicable)

Collaborator:

A better place would be skglm/cv.py IMO; at the least, this is not related to penalties.


def _score(self, y_true, y_pred):
    """Compute the loss or performance score (lower is better)."""
    if hasattr(self.datafit, "loss"):
Collaborator:

Usually we don't use the datafit loss as the score, i.e. HuberRegressor's score is not the Huber loss:
https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.HuberRegressor.html

I'd use negative MSE for regression datafits and accuracy for classification datafits.
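The reviewer's suggestion can be sketched with sklearn metrics; the `classification` flag here is a hypothetical stand-in for the PR's isinstance check on the datafit (Logistic, QuadraticSVC):

```python
from sklearn.metrics import accuracy_score, mean_squared_error

def score(y_true, y_pred, classification=False):
    # higher is better: accuracy for classification datafits,
    # negative MSE for regression datafits
    if classification:
        return accuracy_score(y_true, y_pred)
    return -mean_squared_error(y_true, y_pred)
```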

return -float(scorer._score_func(y_true, y_pred, **scorer._kwargs))

if isinstance(self.datafit, (Logistic, QuadraticSVC)):
    return -float(np.mean(y_true == (y_pred > 0)))
Collaborator:

Why use float here?

Collaborator:

And also, why compare to y_pred > 0 and not y_pred?

Contributor Author:

I honestly just used float for convenience, but I could also call .item() if you prefer that.
For y_pred, you are right, I will change it; it's redundant as self.predict(X) already returns class labels.

Otherwise I could just use sklearn.metrics.accuracy_score and mean_squared_error if that's better.

def fit(self, X, y):
    """Fit the model using cross-validation."""
    if not hasattr(self.penalty, "alpha"):
        raise ValueError("'penalty' must expose an 'alpha' hyper-parameter.")
Collaborator:

Use something like "GLECV only supports penalties which expose an alpha parameter".

"""Fit the model using cross-validation."""
if not hasattr(self.penalty, "alpha"):
raise ValueError("'penalty' must expose an 'alpha' hyper-parameter.")
X, y = np.asarray(X), np.asarray(y)
Collaborator:

Why this? This breaks sparsity if the user passes a sparse scipy matrix.

Contributor Author:

True, I'll remove np.asarray(X), but I'll keep it for y since targets are always dense, or am I missing something?

Collaborator:

All preprocessing of the design matrix and target is done in GeneralizedLinearEstimator.fit. You need to leverage this and not do anything here.

pen = type(self.penalty)(alpha=alpha, **pen_kwargs)

kw = dict(X=X[train], y=y[train], datafit=self.datafit, penalty=pen)
if 'w' in self.solver.solve.__code__.co_varnames:
Collaborator:

What does this do? It seems overly complex.

Contributor Author:

This code was originally checking whether the solver's solve method accepted a w argument, and only passed it if it was there. I reviewed the solvers in the library again and realized that all of them use w_init as the argument for warm starts, so we don't need this dynamic check and can always pass w_init directly.

if 'w' in self.solver.solve.__code__.co_varnames:
    kw['w'] = w_start
w = self.solver.solve(**kw)
w = w[0] if isinstance(w, tuple) else w
Collaborator:

which solvers return a tuple ?

Contributor Author:

Maybe I am misunderstanding something, but most (if not all) solvers return a tuple from the solve method, e.g. in Fista: return w, np.array(p_objs_out), stop_crit

The first element is always the weights/solution.
The check w = w[0] if isinstance(w, tuple) else w is not strictly necessary; it's a safeguard in case a future solver returns just the weights. I can change it to w, *_ if you want it shorter and cleaner, given that we know solvers always return tuples.
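A minimal illustration of the `w, *_` unpacking, with a hypothetical `fake_solve` mimicking the `(w, p_objs_out, stop_crit)` tuple that Fista is said to return:

```python
import numpy as np

def fake_solve():
    # stand-in for a skglm solver's solve(): returns (w, p_objs_out, stop_crit)
    w = np.array([1.0, 2.0])
    return w, np.array([0.5, 0.1]), 1e-6

# cleaner than `w = w[0] if isinstance(w, tuple) else w` when the
# return is known to always be a tuple with the weights first
w, *_ = fake_solve()
```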


coef, intercept = (w[:p], w[p]) if w.size == p + 1 else (w, 0.0)

y_pred = X[test] @ coef + intercept
Collaborator:

For logistic regression the prediction is not this; it's the sign of this vector.

You can delegate everything to a GeneralizedLinearEstimator (creating it by passing datafit, penalty, solver, then changing the alpha attribute and using warm_start=True).

self.mse_path_ = mse_path
return self

def predict(self, X):
Collaborator:

Delegate to an instance of GeneralizedLinearEstimator instead.

super().fit(X, y)
self.alphas_ = alphas
self.mse_path_ = mse_path
return self
Collaborator:

Look at how LassoCV works: after fitting, you can directly call it to predict, because it stores the best Lasso. You need to do the same: store the best model coefficients and the best alpha_, and recreate a GeneralizedLinearEstimator at the end that is fit on the whole dataset X.

Look at how it's done here:
https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/linear_model/_coordinate_descent.py#L1882
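The LassoCV pattern can be sketched with a tiny wrapper (hypothetical `TinyCV`, using sklearn's Lasso as a stand-in for the wrapped GeneralizedLinearEstimator): after CV it stores alpha_ and a full-data refit in estimator_, and predict simply delegates:

```python
import numpy as np
from sklearn.linear_model import Lasso

class TinyCV:
    # sketch of the LassoCV pattern: pick alpha_ by validation, then
    # store a refit estimator so predict() works after fit()
    def fit(self, X, y, alphas=(1.0, 0.1)):
        best_alpha, best_mse = None, np.inf
        for alpha in alphas:
            m = Lasso(alpha=alpha).fit(X[:40], y[:40])        # train split
            mse = np.mean((m.predict(X[40:]) - y[40:]) ** 2)  # val split
            if mse < best_mse:
                best_alpha, best_mse = alpha, mse
        self.alpha_ = best_alpha
        self.estimator_ = Lasso(alpha=best_alpha).fit(X, y)  # full refit
        return self

    def predict(self, X):
        # delegate to the stored best model, as LassoCV does
        return self.estimator_.predict(X)
```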
