Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions examples/ensemble/plot_gradient_boosting_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,44 @@
# %%
# Load Ames Housing dataset
# -------------------------
# First, we load the Ames Housing data as a pandas dataframe. The features
# are either categorical or numerical:
from sklearn.datasets import fetch_openml

X, y = fetch_openml(data_id=41211, as_frame=True, return_X_y=True)

# Select only a subset of features of X to make the example faster to run
categorical_columns_subset = [
    "Bldg_Type",
    "Garage_Finish",
    "Lot_Config",
    "Functional",
    "Mas_Vnr_Type",
    "House_Style",
    "Fireplace_Qu",
    "Exter_Cond",
    "Exter_Qual",
    "Pool_QC",
]

numerical_columns_subset = [
    "Three_season_porch",
    "Fireplaces",
    "Bsmt_Half_Bath",
    "Half_Bath",
    "Garage_Cars",
    "TotRms_AbvGrd",
    "BsmtFin_SF_1",
    "BsmtFin_SF_2",
    "Gr_Liv_Area",
    "Screen_Porch",
]

# Keep the categorical columns first, followed by the numerical ones.
X = X[categorical_columns_subset + numerical_columns_subset]

# Count the features on the reduced frame only. `select_dtypes` with
# include="number" covers every numeric dtype, whereas comparing dtypes to
# "float" would not count integer columns.
n_categorical_features = X.select_dtypes(include="category").shape[1]
n_numerical_features = X.select_dtypes(include="number").shape[1]

print(f"Number of samples: {X.shape[0]}")
print(f"Number of features: {X.shape[1]}")
print(f"Number of categorical features: {n_categorical_features}")
Expand Down Expand Up @@ -114,6 +144,7 @@

# The ordinal encoder emits the categorical features first and the continuous
# (passed-through) features after them, so the mask flags the leading
# `n_categorical_features` positions and leaves the remaining
# `n_numerical_features` positions unflagged.
categorical_mask = [
    *(True for _ in range(n_categorical_features)),
    *(False for _ in range(n_numerical_features)),
]
hist_native = make_pipeline(
ordinal_encoder,
Expand All @@ -134,18 +165,20 @@
import matplotlib.pyplot as plt

scoring = "neg_mean_absolute_percentage_error"
# Number of cross-validation folds, shared by every evaluation so that all
# pipelines are scored with the same number of splits.
n_cv_folds = 3

# Cross-validate each pipeline exactly once; `plot_results` reads these
# module-level result dictionaries.
dropped_result = cross_validate(hist_dropped, X, y, cv=n_cv_folds, scoring=scoring)
one_hot_result = cross_validate(hist_one_hot, X, y, cv=n_cv_folds, scoring=scoring)
ordinal_result = cross_validate(hist_ordinal, X, y, cv=n_cv_folds, scoring=scoring)
native_result = cross_validate(hist_native, X, y, cv=n_cv_folds, scoring=scoring)


def plot_results(figure_title):
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 8))

plot_info = [
("fit_time", "Fit times (s)", ax1, None),
("test_score", "Mean Absolute Percentage Error", ax2, (0, 0.20)),
("test_score", "Mean Absolute Percentage Error", ax2, None),
]

x, width = np.arange(4), 0.9
Expand All @@ -156,11 +189,15 @@ def plot_results(figure_title):
ordinal_result[key],
native_result[key],
]

mape_cv_mean = [np.mean(np.abs(item)) for item in items]
mape_cv_std = [np.std(item) for item in items]

ax.bar(
x,
[np.mean(np.abs(item)) for item in items],
width,
yerr=[np.std(item) for item in items],
x=x,
height=mape_cv_mean,
width=width,
yerr=mape_cv_std,
color=["C0", "C1", "C2", "C3"],
)
ax.set(
Expand All @@ -173,7 +210,7 @@ def plot_results(figure_title):
fig.suptitle(figure_title)


# Draw the comparison figure once, titled after the dataset actually used
# (Ames Housing).
plot_results("Gradient Boosting on Ames Housing")

# %%
# We see that the model with one-hot-encoded data is by far the slowest. This
Expand Down Expand Up @@ -219,12 +256,12 @@ def plot_results(figure_title):
histgradientboostingregressor__max_iter=15,
)

# Re-evaluate every pipeline once with the settings applied just above
# (e.g. `max_iter=15`), overwriting the earlier results that `plot_results`
# reads.
dropped_result = cross_validate(hist_dropped, X, y, cv=n_cv_folds, scoring=scoring)
one_hot_result = cross_validate(hist_one_hot, X, y, cv=n_cv_folds, scoring=scoring)
ordinal_result = cross_validate(hist_ordinal, X, y, cv=n_cv_folds, scoring=scoring)
native_result = cross_validate(hist_native, X, y, cv=n_cv_folds, scoring=scoring)

# Draw a single figure, titled after the dataset actually used (Ames Housing).
plot_results("Gradient Boosting on Ames Housing (few and small trees)")

plt.show()

Expand Down