Description
According to the current documentation, GridSearchCV accepts as its estimator parameter any object that implements the fit and predict methods.
While this is fine for most cases, certain use cases become quite unintuitive under this API.
For instance, consider the AdaBoostClassifier API. Essentially, this classifier just wraps boosting around whatever classifier is provided via the base_estimator parameter. Most of the parameter tuning therefore happens in this base_estimator rather than in the booster itself. If I were to use grid search for parameter tuning, I would probably do something along the lines of:
base_estimators = [DecisionTreeClassifier(max_depth=d) for d in range(1, 11)]
grid = GridSearchCV(AdaBoostClassifier(), dict(base_estimator=base_estimators))
This is already quite ugly, and I am only tuning the max_depth parameter. Imagine if I also wanted to tune some other parameter of the DecisionTreeClassifier class.
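For reference, the snippet above can be fleshed out into a runnable form. The toy dataset, the reduced depth range, and the version check are my additions; the check is needed because scikit-learn renamed base_estimator to estimator in version 1.2:

```python
import inspect

from sklearn.datasets import make_classification
from sklearn.ensemble import AdaBoostClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

# Pick whichever parameter name this scikit-learn version uses.
param = ("base_estimator"
         if "base_estimator" in inspect.signature(AdaBoostClassifier.__init__).parameters
         else "estimator")

# One fully constructed base estimator per candidate depth.
base_estimators = [DecisionTreeClassifier(max_depth=d) for d in range(1, 4)]
grid = GridSearchCV(AdaBoostClassifier(n_estimators=5),
                    {param: base_estimators}, cv=2)

X, y = make_classification(n_samples=60, random_state=0)
grid.fit(X, y)
print(grid.best_params_)
```

Note that the grid is a list of whole estimator objects, not of parameter values, which is exactly what makes the approach awkward to extend to a second tuned parameter.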
One way to fix this would be to make GridSearchCV accept factory functions for classifiers, not only the classifiers themselves. In particular, something along these lines could make things a bit easier:
def ada_factory(*args, **kwargs):
    return AdaBoostClassifier(DecisionTreeClassifier(*args, **kwargs))

grid = GridSearchCV(ada_factory, dict(max_depth=range(1, 11)))
Obviously, the contract that the objects returned from the factory function must provide fit and predict methods should remain in place.
Not only does this solve this particular problem, it would also allow one to test multiple estimators within the same grid search -- just add a parameter to your factory function.
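A minimal sketch of that idea, assuming the proposed factory contract; the kind parameter and the factory body here are hypothetical, chosen only to illustrate switching between estimator types:

```python
from sklearn.ensemble import AdaBoostClassifier
from sklearn.tree import DecisionTreeClassifier, ExtraTreeClassifier

def ada_factory(kind="tree", **kwargs):
    # The extra "kind" parameter selects which base classifier to wrap,
    # so a single grid could sweep over estimator types as well.
    base_cls = DecisionTreeClassifier if kind == "tree" else ExtraTreeClassifier
    return AdaBoostClassifier(base_cls(**kwargs))

# Under the proposed API (not valid with today's GridSearchCV), one grid
# would then cover both the estimator type and its depth:
# grid = GridSearchCV(ada_factory,
#                     dict(kind=["tree", "extra"], max_depth=range(1, 11)))
```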