Closed
Description
Describe the workflow you want to enable
The quantile loss function used by the Gradient Boosting Regressor is too conservative when predicting extreme quantiles.
This makes the quantile regression almost equivalent to looking up the dataset's quantile, which is not really useful.
Describe your proposed solution
Use the same type of loss function as in the scikit-garden package.
Describe alternatives you've considered, if relevant
When the GB regressor overfits, this behavior seems to go away.
Additional context
import pandas as pd
from sklearn.datasets import load_boston
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import GradientBoostingRegressor
from skgarden import RandomForestQuantileRegressor
# Load the Boston housing dataset and wrap features/target in pandas containers.
# NOTE(review): load_boston was removed in scikit-learn 1.2, so this snippet
# needs an older scikit-learn release to run as written.
data = load_boston()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = pd.Series(data.target)
# with sklearn:
# One estimator is reused and refitted for each quantile level (alpha), in the
# order median, 5%, 95%; predictions are made on the training data itself.
gb_learn = GradientBoostingRegressor(loss="quantile", n_estimators=20, max_depth=10)
pred_learn_median, pred_learn_m_ci, pred_learn_p_ci = [
    gb_learn.set_params(alpha=level).fit(X, y).predict(X)
    for level in (0.5, 0.05, 0.95)
]

# Scatter predicted vs. true values for each quantile on a single figure.
fig = plt.figure(figsize=(12, 8))
for label, predicted in (
    ("Median", pred_learn_median),
    ("5% quantile", pred_learn_m_ci),
    ("95% quantile", pred_learn_p_ci),
):
    sns.scatterplot(x=y, y=predicted, label=label)
# Red identity line: points on it are predicted exactly.
plt.plot([0, 50], [0, 50], c="red")
sns.despine()
plt.xlabel("True value")
plt.ylabel("Predicted value")
plt.show()
# with skgarden
# NOTE(review): max_depth here (3) differs from the sklearn model above (10),
# so the two plots are not a like-for-like comparison.
rf_garden = RandomForestQuantileRegressor(n_estimators=20, max_depth=3)
# The forest must be fitted before predicting; the original snippet skipped
# this step and would fail with a not-fitted error.
rf_garden.fit(X, y)
# skgarden takes the quantile as a percentage (0-100) at predict time, so a
# single fitted forest serves all three quantiles.
pred_garden_median = rf_garden.predict(X, quantile=50)
pred_garden_m_ci = rf_garden.predict(X, quantile=5)
pred_garden_p_ci = rf_garden.predict(X, quantile=95)

# Scatter predicted vs. true values for each quantile on a single figure.
fig = plt.figure(figsize=(12, 8))
sns.scatterplot(x=y, y=pred_garden_median, label="Median")
sns.scatterplot(x=y, y=pred_garden_m_ci, label="5% quantile")
sns.scatterplot(x=y, y=pred_garden_p_ci, label="95% quantile")
# Red identity line: points on it are predicted exactly.
plt.plot([0, 50], [0, 50], c="red")
sns.despine()
plt.xlabel("True value")
plt.ylabel("Predicted value")
plt.show()