Skip to content

Commit 3365e14

Browse files
committed
fix examples
1 parent dcf6f7d commit 3365e14

7 files changed

+51
-23
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ _setup_ext.txt
1818
coverage.html/*
1919
_cache/*
2020
_deps/*
21+
simages/*
2122
.vs/*
2223
*.dir/*
2324
Release/*

_doc/examples/plot_digitize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,9 @@
8484

8585
onx = to_onnx(tree, x.reshape((-1, 1)), target_opset=15)
8686

87-
sess = InferenceSession(onx.SerializeToString())
87+
sess = InferenceSession(
88+
onx.SerializeToString(), providers=["CPUExecutionProvider"]
89+
)
8890

8991
ti = measure_time(
9092
"sess.run(None, {'X': x})",

_doc/examples/plot_piecewise_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def graph(X, Y, model):
8989
#
9090

9191

92-
ax = seaborn.scatterplot("x1", "x2", "bucket", data=df, palette="Set1", s=400)
92+
ax = seaborn.scatterplot(x="x1", y="x2", hue="bucket", data=df, palette="Set1", s=400)
9393
seaborn.scatterplot(
94-
"x1", "x2", "label", data=df, palette="Set1", marker="o", ax=ax, s=100
94+
x="x1", y="x2", hue="label", data=df, palette="Set1", marker="o", ax=ax, s=100
9595
)
9696
ax.set_title("buckets")
9797

_doc/examples/plot_piecewise_linear_regression_criterion.py

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494

9595

9696
model2 = DecisionTreeRegressor(
97-
min_samples_leaf=100, criterion=SimpleRegressorCriterion(X_train)
97+
min_samples_leaf=100, criterion=SimpleRegressorCriterion(1, X_train.shape[0])
9898
)
9999
model2.fit(X_train, y_train)
100100

@@ -216,7 +216,7 @@
216216

217217

218218
model3 = DecisionTreeRegressor(
219-
min_samples_leaf=100, criterion=SimpleRegressorCriterionFast(X_train)
219+
min_samples_leaf=100, criterion=SimpleRegressorCriterionFast(1, X_train.shape[0])
220220
)
221221
model3.fit(X_train, y_train)
222222
pred = model3.predict(X_test)
@@ -282,7 +282,7 @@
282282

283283

284284
model3 = DecisionTreeRegressor(
285-
min_samples_leaf=100, criterion=SimpleRegressorCriterionFast(X_train3)
285+
min_samples_leaf=100, criterion=SimpleRegressorCriterionFast(1, X_train3.shape[0])
286286
)
287287
measure_time("model3.fit(X_train3, y_train3)", globals())
288288

@@ -309,62 +309,71 @@
309309
# into buckets and the prediction function of the decision tree which now
310310
# needs to return a dot product.
311311

312-
313-
piece = PiecewiseTreeRegressor(criterion="mselin", min_samples_leaf=100)
314-
piece.fit(X_train, y_train)
312+
fixed = False
313+
if fixed:
314+
# It does not work yet.
315+
piece = PiecewiseTreeRegressor(criterion="mselin", min_samples_leaf=100)
316+
piece.fit(X_train, y_train)
315317

316318

317319
#################################
318320
#
319321

320322

321-
pred = piece.predict(X_test)
322-
pred[:5]
323+
if fixed:
324+
pred = piece.predict(X_test)
325+
pred[:5]
323326

324327

325328
#################################
326329
#
327330

328331

329-
fig, ax = plt.subplots(1, 1)
330-
ax.plot(X_test[:, 0], y_test, ".", label="data")
331-
ax.plot(X_test[:, 0], pred, ".", label="predictions")
332-
ax.set_title("DecisionTreeRegressor\nwith criterion adapted to linear regression")
333-
ax.legend()
332+
if fixed:
333+
fig, ax = plt.subplots(1, 1)
334+
ax.plot(X_test[:, 0], y_test, ".", label="data")
335+
ax.plot(X_test[:, 0], pred, ".", label="predictions")
336+
ax.set_title("DecisionTreeRegressor\nwith criterion adapted to linear regression")
337+
ax.legend()
334338

335339
#################################
336340
# The coefficients for the linear regressions are kept into the following attribute:
337341

338342

339-
piece.betas_
343+
if fixed:
344+
piece.betas_
340345

341346

342347
#################################
343348
# Mapped to the following leaves:
344349

345350

346-
piece.leaves_index_, piece.leaves_mapping_
351+
if fixed:
352+
piece.leaves_index_, piece.leaves_mapping_
347353

348354

349355
#################################
350356
# We can get the leave each observation falls into:
351357

352358

353-
piece.predict_leaves(X_test)[:5]
359+
if fixed:
360+
piece.predict_leaves(X_test)[:5]
354361

355362

356363
#################################
357364
# The training is quite slow as it is training many
358365
# linear regressions each time a split is evaluated.
359366

360367

361-
measure_time("piece.fit(X_train, y_train)", globals())
368+
if fixed:
369+
measure_time("piece.fit(X_train, y_train)", globals())
362370

363371

364372
#################################
365373
#
366374

367-
measure_time("piece.fit(X_train3, y_train3)", globals())
375+
if fixed:
376+
measure_time("piece.fit(X_train3, y_train3)", globals())
368377

369378

370379
#################################

_doc/examples/plot_regression_confidence_interval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@
216216

217217

218218
df_train = pandas.DataFrame(dict(X=X_train.ravel(), y=y_train))
219-
g = sns.jointplot("X", "y", data=df_train, kind="reg", color="m", height=7)
219+
g = sns.jointplot(x="X", y="y", data=df_train, kind="reg", color="m", height=7)
220220
g.ax_joint.plot(X_test, y_test, "ro")
221221

222222

_doc/examples/plot_search_images_torch.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@
6868
if not os.path.exists("simages/category"):
6969
os.makedirs("simages/category")
7070

71-
files = unzip_files("data/dog-cat-pixabay.zip", where_to="simages/category")
71+
url = "https://github.com/sdpython/mlinsights/raw/ref/_doc/examples/data/dog-cat-pixabay.zip"
72+
files = unzip_files(url, where_to="simages/category")
7273
len(files), files[0]
7374

7475
##########################################

mlinsights/ext_test_case.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import warnings
55
import zipfile
66
from io import BytesIO
7+
from urllib.request import urlopen
78
from argparse import ArgumentParser
89
from contextlib import redirect_stderr, redirect_stdout
910
from io import StringIO
@@ -442,6 +443,20 @@ def unzip_files(
442443
:param verbose: display file names
443444
:return: list of unzipped files
444445
"""
446+
if zipf.startswith("https:"):
447+
filename = zipf.split("/")[-1]
448+
dest_zip = os.path.join(where_to, filename)
449+
if not os.path.exists(dest_zip):
450+
if verbose:
451+
print(f"[unzip_files] downloads into {dest_zip!r} from {zipf!r}")
452+
with urlopen(zipf, timeout=10) as u:
453+
content = u.read()
454+
with open(dest_zip, "wb") as f:
455+
f.write(content)
456+
elif verbose:
457+
print(f"[unzip_files] already downloaded {dest_zip!r}")
458+
zipf = dest_zip
459+
445460
if isinstance(zipf, bytes):
446461
zipf = BytesIO(zipf)
447462

0 commit comments

Comments
 (0)