
Fix gradient boosting quantile regression #18849

Closed
@Bougeant

Description


Describe the workflow you want to enable

The quantile loss used by the gradient boosting regressor (GradientBoostingRegressor with loss="quantile") gives predictions that are too conservative for extreme target values.

As a result, the quantile regression is almost equivalent to looking up the quantile of the target over the whole dataset, which is not very useful.
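For reference, the quantile loss that GradientBoostingRegressor minimizes for a given alpha is the pinball loss. Below is a minimal NumPy sketch of that loss and of its negative gradient, which is what each new tree is fitted to. The function names are illustrative and are not scikit-learn API. Only the sign of each residual enters the gradient, not its size.

```python
import numpy as np


def pinball_loss(y_true, y_pred, alpha):
    """Mean quantile (pinball) loss for a quantile level alpha in (0, 1)."""
    diff = np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float)
    # Under-predictions (diff > 0) cost alpha per unit; over-predictions cost (1 - alpha).
    return float(np.mean(np.where(diff > 0, alpha * diff, (alpha - 1.0) * diff)))


def pinball_negative_gradient(y_true, y_pred, alpha):
    """Negative gradient of the pinball loss w.r.t. y_pred (the pseudo-residuals)."""
    # Depends only on whether each target lies above the current prediction.
    return np.where(np.asarray(y_true) > np.asarray(y_pred), alpha, alpha - 1.0)


# Usage: the best constant prediction under the pinball loss is the alpha-quantile of y.
y = np.arange(101.0)
candidates = np.linspace(0.0, 100.0, 1001)
losses = [pinball_loss(y, np.full_like(y, c), alpha=0.9) for c in candidates]
best_constant = candidates[int(np.argmin(losses))]  # close to np.quantile(y, 0.9) == 90.0
```

This is why a boosted quantile model starts from the dataset's alpha-quantile (scikit-learn's initial estimate for this loss is a constant quantile). If the subsequent trees move the predictions only a little, the output stays close to that constant.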

Describe your proposed solution

Use the same approach as the scikit-garden package, whose RandomForestQuantileRegressor implements quantile regression forests.
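For context, a quantile regression forest grows its trees roughly as an ordinary random forest does. It then reads a quantile from the training targets that share leaves with the query point, weighting each target by how often it co-occurs in a leaf with that point. The sketch below is not scikit-garden's code. It is a minimal, hypothetical illustration built on scikit-learn's RandomForestRegressor.apply, and the helper name qrf_predict is made up.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor


def qrf_predict(forest, X_train, y_train, X, q):
    """Hypothetical helper: quantile-regression-forest style prediction.

    `forest` must already be fitted on (X_train, y_train); q is in [0, 1].
    """
    train_leaves = forest.apply(X_train)  # leaf index per (training sample, tree)
    test_leaves = forest.apply(X)         # leaf index per (query sample, tree)
    y_train = np.asarray(y_train, dtype=float)
    order = np.argsort(y_train)
    y_sorted = y_train[order]
    preds = np.empty(test_leaves.shape[0])
    for i, leaves in enumerate(test_leaves):
        # same_leaf[j, t] is True when training sample j lands in the same leaf of tree t as query i.
        same_leaf = train_leaves == leaves
        # Each tree spreads a weight of 1 evenly over its co-leaf training samples;
        # averaging over trees gives per-sample weights that sum to 1.
        weights = (same_leaf / same_leaf.sum(axis=0)).mean(axis=1)
        cdf = np.cumsum(weights[order])
        # Smallest training target whose weighted CDF reaches q.
        idx = np.searchsorted(cdf, q * cdf[-1])
        preds[i] = y_sorted[min(idx, len(y_sorted) - 1)]
    return preds


# Usage on synthetic data whose noise grows with x (an assumption for illustration).
rng = np.random.RandomState(0)
X_train = rng.uniform(0, 10, size=(500, 1))
y_train = X_train[:, 0] + rng.normal(scale=0.5 + 0.3 * X_train[:, 0])
forest = RandomForestRegressor(n_estimators=50, min_samples_leaf=5, random_state=0)
forest.fit(X_train, y_train)
X_query = np.array([[1.0], [9.0]])
lower = qrf_predict(forest, X_train, y_train, X_query, 0.05)
upper = qrf_predict(forest, X_train, y_train, X_query, 0.95)
```

Because the whole conditional distribution is kept in the leaf weights, the 5%-95% interval at x = 9 should come out wider than at x = 1, following the noise in the data.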

Describe alternatives you've considered, if relevant

When the gradient boosting regressor overfits, this behavior seems to go away.
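The script below evaluates each model on the same data it was trained on, so an overfit model can look better than it is. A held-out check separates the two cases. This minimal sketch assumes synthetic heteroscedastic data instead of the Boston dataset, and empirical_coverage is a hypothetical helper. It measures the fraction of test targets at or below each predicted quantile. For a well-calibrated alpha-quantile, that fraction should be close to alpha.

```python
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import train_test_split


def empirical_coverage(y_true, y_pred):
    """Hypothetical helper: fraction of targets at or below the predicted quantile."""
    return float(np.mean(np.asarray(y_true) <= np.asarray(y_pred)))


rng = np.random.RandomState(0)
X = rng.uniform(0, 10, size=(2000, 1))
# Heteroscedastic noise: the spread grows with x, so useful quantiles must depend on X.
y = X[:, 0] + rng.normal(scale=0.5 + 0.3 * X[:, 0])
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

coverage = {}
for alpha in (0.05, 0.5, 0.95):
    gbr = GradientBoostingRegressor(
        loss="quantile", alpha=alpha, n_estimators=100, max_depth=3, random_state=0
    )
    gbr.fit(X_train, y_train)
    # Coverage is measured on data the model has not seen.
    coverage[alpha] = empirical_coverage(y_test, gbr.predict(X_test))
```

Comparing coverage on the training split with coverage on the test split shows whether narrower, more "extreme" quantiles generalize or merely reflect overfitting.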

Additional context

import pandas as pd
from sklearn.datasets import load_boston  # deprecated in scikit-learn 1.0, removed in 1.2
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from skgarden import RandomForestQuantileRegressor

data = load_boston()
X = pd.DataFrame(data=data["data"], columns=data["feature_names"])
y = pd.Series(data=data["target"])

# With scikit-learn: refit the same model once per quantile level (alpha)
# and predict on the training data itself.
gb_learn = GradientBoostingRegressor(loss="quantile", n_estimators=20, max_depth=10)

gb_learn.set_params(alpha=0.5)  # median
gb_learn.fit(X, y)
pred_learn_median = gb_learn.predict(X)
gb_learn.set_params(alpha=0.05)  # lower bound of the 90% interval
gb_learn.fit(X, y)
pred_learn_m_ci = gb_learn.predict(X)
gb_learn.set_params(alpha=0.95)  # upper bound of the 90% interval
gb_learn.fit(X, y)
pred_learn_p_ci = gb_learn.predict(X)

fig = plt.figure(figsize=(12, 8))
sns.scatterplot(x=y, y=pred_learn_median, label="Median")
sns.scatterplot(x=y, y=pred_learn_m_ci, label="5% quantile")
sns.scatterplot(x=y, y=pred_learn_p_ci, label="95% quantile")
plt.plot([0, 50], [0, 50], c="red")
sns.despine()
plt.xlabel("True value")
plt.ylabel("Predicted value")
plt.show()

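# With scikit-garden: the quantile is passed at prediction time, in percent (0-100).
rf_garden = RandomForestQuantileRegressor(n_estimators=20, max_depth=3)
rf_garden.fit(X, y)  # the model must be fitted before calling predict()
pred_garden_median = rf_garden.predict(X, quantile=50)
pred_garden_m_ci = rf_garden.predict(X, quantile=5)
pred_garden_p_ci = rf_garden.predict(X, quantile=95)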

fig = plt.figure(figsize=(12, 8))
sns.scatterplot(x=y, y=pred_garden_median, label="Median")
sns.scatterplot(x=y, y=pred_garden_m_ci, label="5% quantile")
sns.scatterplot(x=y, y=pred_garden_p_ci, label="95% quantile")
plt.plot([0, 50], [0, 50], c="red")
sns.despine()
plt.xlabel("True value")
plt.ylabel("Predicted value")
plt.show()
