Skip to content

Commit ffae2ed

Browse files
elibixbyJon Wayne Parrott
authored and
Jon Wayne Parrott
committed
Fix ml_engine tests (GoogleCloudPlatform#850)
1 parent 87f7d24 commit ffae2ed

File tree

2 files changed

+16
-18
lines changed

2 files changed

+16
-18
lines changed

ml_engine/online_prediction/predict.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ def predict_json(project, model, instances, version=None):
6262

6363

6464
# [START predict_tf_records]
65-
def predict_tf_records(project,
66-
model,
67-
example_bytes_list,
68-
version=None):
65+
def predict_examples(project,
66+
model,
67+
example_bytes_list,
68+
version=None):
6969
"""Send protocol buffer data to a deployed model for prediction.
7070
7171
Args:
@@ -119,7 +119,7 @@ def census_to_example_bytes(json_instance):
119119
"""
120120
import tensorflow as tf
121121
feature_dict = {}
122-
for key, data in json_instance.iteritems():
122+
for key, data in six.iteritems(json_instance):
123123
if isinstance(data, six.string_types):
124124
feature_dict[key] = tf.train.Feature(
125125
bytes_list=tf.train.BytesList(value=[str(data)]))
@@ -153,7 +153,7 @@ def main(project, model, version=None, force_tfrecord=False):
153153
census_to_example_bytes(e)
154154
for e in user_input
155155
]
156-
result = predict_tf_records(
156+
result = predict_examples(
157157
project, model, example_bytes_list, version=version)
158158
else:
159159
result = predict_json(

ml_engine/online_prediction/predict_test.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222

2323

2424
MODEL = 'census'
25-
VERSION = 'v1'
26-
TF_RECORDS_VERSION = 'v1tfrecord'
25+
JSON_VERSION = 'v1json'
26+
EXAMPLES_VERSION = 'v1example'
2727
PROJECT = 'python-docs-samples-tests'
2828
JSON = {
2929
'age': 25,
@@ -41,22 +41,21 @@
4141
'native_country': ' United-States'
4242
}
4343
EXPECTED_OUTPUT = {
44-
u'probabilities': [0.9942260384559631, 0.005774002522230148],
45-
u'logits': [-5.148599147796631],
46-
u'classes': 0,
47-
u'logistic': [0.005774001590907574]
44+
u'confidence': 0.7760371565818787,
45+
u'predictions': u' <=50K'
4846
}
4947

5048

5149
def test_predict_json():
5250
result = predict.predict_json(
53-
PROJECT, MODEL, [JSON, JSON], version=VERSION)
51+
PROJECT, MODEL, [JSON, JSON], version=JSON_VERSION)
5452
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
5553

5654

5755
def test_predict_json_error():
5856
with pytest.raises(RuntimeError):
59-
predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
57+
predict.predict_json(
58+
PROJECT, MODEL, [{"foo": "bar"}], version=JSON_VERSION)
6059

6160

6261
@pytest.mark.slow
@@ -66,9 +65,8 @@ def test_census_example_to_bytes():
6665

6766

6867
@pytest.mark.slow
69-
@pytest.mark.xfail('Single placeholder inputs broken in service b/35778449')
70-
def test_predict_tfrecords():
68+
def test_predict_examples():
7169
b = predict.census_to_example_bytes(JSON)
72-
result = predict.predict_tfrecords(
73-
PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION)
70+
result = predict.predict_examples(
71+
PROJECT, MODEL, [b, b], version=EXAMPLES_VERSION)
7472
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result

0 commit comments

Comments
 (0)