Skip to content

Commit 85a8aa6

Browse files
committed
BENCH att no-interactions option to bench_hist_gradient_boosting_higgsboson
1 parent d7db6ac commit 85a8aa6

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

benchmarks/bench_hist_gradient_boosting_higgsboson.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
parser.add_argument("--max-bins", type=int, default=255)
2525
parser.add_argument("--no-predict", action="store_true", default=False)
2626
parser.add_argument("--cache-loc", type=str, default="/tmp")
27+
parser.add_argument("--no-interactions", type=bool, default=False)
2728
args = parser.parse_args()
2829

2930
HERE = os.path.dirname(__file__)
@@ -88,6 +89,11 @@ def predict(est, data_test, target_test):
8889
n_samples, n_features = data_train.shape
8990
print(f"Training set with {n_samples} records with {n_features} features.")
9091

92+
if args.no_interactions:
93+
interaction_cst = [[i] for i in range(n_features)]
94+
else:
95+
interaction_cst = None
96+
9197
est = HistGradientBoostingClassifier(
9298
loss="log_loss",
9399
learning_rate=lr,
@@ -97,6 +103,7 @@ def predict(est, data_test, target_test):
97103
early_stopping=False,
98104
random_state=0,
99105
verbose=1,
106+
interaction_cst=interaction_cst,
100107
)
101108
fit(est, data_train, target_train, "sklearn")
102109
predict(est, data_test, target_test)

0 commit comments

Comments
 (0)