Skip to content

Commit ec2e1dc

Browse files
committed
use updated client
1 parent 9df9474 commit ec2e1dc

File tree

2 files changed

+32
-21
lines changed

2 files changed

+32
-21
lines changed

tables/automl/automl_tables_predict.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@
2525
import os
2626

2727

28-
def predict(project_id, compute_region, model_display_name, inputs, feature_importance=None):
28+
def predict(
29+
project_id,
30+
compute_region,
31+
model_display_name,
32+
inputs,
33+
feature_importance=None,
34+
):
2935
"""Make a prediction."""
3036
# [START automl_tables_predict]
3137
# TODO(developer): Uncomment and set the following variables
@@ -38,13 +44,16 @@ def predict(project_id, compute_region, model_display_name, inputs, feature_impo
3844

3945
client = automl.TablesClient(project=project_id, region=compute_region)
4046

41-
params = {}
4247
if feature_importance:
43-
params = {"feature_importance": feature_importance}
44-
45-
response = client.predict(
46-
model_display_name=model_display_name, inputs=inputs, params=params
47-
)
48+
response = client.predict(
49+
model_display_name=model_display_name,
50+
inputs=inputs,
51+
feature_importance=True,
52+
)
53+
else:
54+
response = client.predict(
55+
model_display_name=model_display_name, inputs=inputs
56+
)
4857

4958
print("Prediction results:")
5059
for result in response.payload:
@@ -53,20 +62,21 @@ def predict(project_id, compute_region, model_display_name, inputs, feature_impo
5362
)
5463
print("Predicted class score: {}".format(result.tables.score))
5564

56-
# get features of top importance
57-
feat_list = [
58-
(column.feature_importance, column.column_display_name)
59-
for column in result.tables.tables_model_column_info
60-
]
61-
feat_list.sort(reverse=True)
62-
if len(feat_list) < 10:
63-
feat_to_show = len(feat_list)
64-
else:
65-
feat_to_show = 10
66-
67-
print("Features of top importance:")
68-
for feat in feat_list[:feat_to_show]:
69-
print(feat)
65+
if feature_importance:
66+
# get features of top importance
67+
feat_list = [
68+
(column.feature_importance, column.column_display_name)
69+
for column in result.tables.tables_model_column_info
70+
]
71+
feat_list.sort(reverse=True)
72+
if len(feat_list) < 10:
73+
feat_to_show = len(feat_list)
74+
else:
75+
feat_to_show = 10
76+
77+
print("Features of top importance:")
78+
for feat in feat_list[:feat_to_show]:
79+
print(feat)
7080

7181
# [END automl_tables_predict]
7282

tables/automl/predict_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def test_predict(capsys):
5252
out, _ = capsys.readouterr()
5353
assert "Predicted class name:" in out
5454
assert "Predicted class score:" in out
55+
assert "Features of top importance:" in out
5556

5657

5758
def ensure_model_online():

0 commit comments

Comments
 (0)