GridSearchCV should accept factory functions for estimators #4949

Closed
@lukauskas

Description

According to the current documentation, GridSearchCV accepts as its estimator parameter an object of a type that implements the “fit” and “predict” methods.
While this is fine for most cases, the API makes certain use cases quite unintuitive.

For instance, consider the AdaBoostClassifier API. This classifier essentially just wraps boosting around whatever classifier is provided by the base_estimator parameter. Most of the parameter tuning therefore happens in base_estimator rather than in the booster itself. If I were to use grid search for parameter tuning, I would probably do something along the lines of:

```python
base_estimators = [DecisionTreeClassifier(max_depth=d) for d in range(1, 11)]
grid = GridSearchCV(AdaBoostClassifier(), dict(base_estimator=base_estimators))
```

This is already quite ugly, and I am only tuning the max_depth parameter. Imagine if I also wanted to tune another parameter of the DecisionTreeClassifier class: every combination would need its own pre-built instance in that list.
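A minimal sketch of how that hand-built list grows once a second tree parameter is tuned. Plain dicts stand in for DecisionTreeClassifier instances so it runs without scikit-learn, and the min_samples_leaf values are hypothetical, chosen only for illustration:

```python
from itertools import product

max_depths = range(1, 11)
# Hypothetical candidate values for a second DecisionTreeClassifier parameter.
min_samples_leaf_values = [1, 2, 5, 10, 20]

# With the current API, every combination must become its own base estimator;
# each dict here stands in for DecisionTreeClassifier(**params).
base_estimator_params = [
    {"max_depth": d, "min_samples_leaf": m}
    for d, m in product(max_depths, min_samples_leaf_values)
]

print(len(base_estimator_params))  # 50
```

Ten depths times five leaf sizes already means fifty hand-written candidates, and each further parameter multiplies that again.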

One way to fix this is to make GridSearchCV accept factory functions for classifiers, and not only the classifiers themselves.

In particular, something along these lines would make things a bit easier:

```python
def ada_factory(*args, **kwargs):
    return AdaBoostClassifier(DecisionTreeClassifier(*args, **kwargs))

# Proposed usage: GridSearchCV does not currently accept a factory function.
grid = GridSearchCV(ada_factory, dict(max_depth=range(1, 11)))
```

Obviously, the contract that the objects returned from the factory function implement fit and predict should remain in place.
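To make the proposal concrete, here is a minimal sketch of what a factory-accepting grid search could do internally. `factory_grid_search`, `expand_grid` and the toy `ThresholdClassifier` are hypothetical names, not scikit-learn API, and scoring is a plain callable on the training data rather than real cross-validation:

```python
from itertools import product


def expand_grid(param_grid):
    """Yield one dict per combination in a {name: [values]} grid."""
    names = sorted(param_grid)
    for values in product(*(param_grid[name] for name in names)):
        yield dict(zip(names, values))


def factory_grid_search(factory, param_grid, evaluate):
    """Hypothetical grid search that builds each candidate from a factory.

    factory:  callable taking the grid parameters as keyword arguments and
              returning a fresh estimator.
    evaluate: callable(estimator) -> float, higher is better; a real version
              would fit and cross-validate here.
    """
    best_params, best_score = None, None
    for params in expand_grid(param_grid):
        estimator = factory(**params)
        # Keep the existing contract: the factory's product must have fit/predict.
        if not (hasattr(estimator, "fit") and hasattr(estimator, "predict")):
            raise TypeError("factory must return an object with fit and predict")
        score = evaluate(estimator)
        if best_score is None or score > best_score:
            best_params, best_score = params, score
    return best_params, best_score


# Toy one-parameter classifier standing in for AdaBoost + decision tree.
class ThresholdClassifier:
    def __init__(self, threshold):
        self.threshold = threshold

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [int(x >= self.threshold) for x in X]


X = [1, 2, 3, 4, 5, 6]
y = [0, 0, 0, 1, 1, 1]


def accuracy(estimator):
    # Scores on the training data for brevity; not a substitute for CV.
    estimator.fit(X, y)
    predictions = estimator.predict(X)
    return sum(p == t for p, t in zip(predictions, y)) / len(y)


best_params, best_score = factory_grid_search(
    ThresholdClassifier, {"threshold": range(1, 7)}, accuracy
)
print(best_params, best_score)  # {'threshold': 4} 1.0
```

A class is itself a callable that returns a fresh object, so ThresholdClassifier can be passed as the factory directly; a function like ada_factory would slot in the same way, since the grid values arrive as keyword arguments.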

Not only would this solve this particular problem, it would also allow one to test multiple estimators within the same grid search: just add a parameter to your factory function.
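A sketch of that idea under the same hypothetical factory convention: a `kind` parameter picks the estimator family, and the grid is split into one sub-grid per family so that a parameter like `threshold` is not crossed with models that ignore it (scikit-learn's `param_grid` already accepts a list of dicts for a similar reason). The toy classes and `model_factory` are illustrative only:

```python
from itertools import product


class ThresholdClassifier:
    """Toy classifier: predicts 1 when x >= threshold."""

    def __init__(self, threshold):
        self.threshold = threshold

    def fit(self, X, y):
        return self

    def predict(self, X):
        return [int(x >= self.threshold) for x in X]


class MajorityClassifier:
    """Toy classifier: always predicts the most common label seen in fit."""

    def fit(self, X, y):
        self.label_ = max(set(y), key=list(y).count)
        return self

    def predict(self, X):
        return [self.label_ for _ in X]


def model_factory(kind, threshold=None):
    """Hypothetical factory whose `kind` parameter selects the estimator family."""
    if kind == "threshold":
        return ThresholdClassifier(threshold)
    if kind == "majority":
        return MajorityClassifier()
    raise ValueError(f"unknown kind: {kind!r}")


# One sub-grid per family, so family-specific parameters stay separate.
param_grids = [
    {"kind": ["threshold"], "threshold": [2, 4]},
    {"kind": ["majority"]},
]


def expand(grid):
    """Yield one dict per combination in a {name: [values]} grid."""
    names = sorted(grid)
    for values in product(*(grid[name] for name in names)):
        yield dict(zip(names, values))


candidates = [
    (params, model_factory(**params))
    for grid in param_grids
    for params in expand(grid)
]
for params, estimator in candidates:
    print(params, type(estimator).__name__)
```

Each candidate can then be scored exactly like a single-family grid; the factory is the only place that knows which model types are being compared.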
