Skip to content

Commit b0417c9

Browse files
authored
Add samples for the Cloud ML Engine (GoogleCloudPlatform#824)
Samples for using online prediction with JSON and TFRecord inputs.
1 parent 8123cdd commit b0417c9

File tree

4 files changed

+274
-0
lines changed

4 files changed

+274
-0
lines changed

ml_engine/online_prediction/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
#!/bin/python
2+
# Copyright 2017 Google Inc. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Examples of using the Cloud ML Engine's online prediction service."""
17+
import argparse
18+
import base64
19+
import json
20+
21+
# [START import_libraries]
22+
import googleapiclient.discovery
23+
# [END import_libraries]
24+
import six
25+
26+
27+
# [START predict_json]
28+
def predict_json(project, model, instances, version=None):
29+
"""Send json data to a deployed model for prediction.
30+
31+
Args:
32+
project (str): project where the Cloud ML Engine Model is deployed.
33+
model (str): model name.
34+
instances ([Mapping[str: Any]]): Keys should be the names of Tensors
35+
your deployed model expects as inputs. Values should be datatypes
36+
convertible to Tensors, or (potentially nested) lists of datatypes
37+
convertible to tensors.
38+
version: str, version of the model to target.
39+
Returns:
40+
Mapping[str: any]: dictionary of prediction results defined by the
41+
model.
42+
"""
43+
# Create the ML Engine service object.
44+
# To authenticate set the environment variable
45+
# GOOGLE_APPLICATION_CREDENTIALS=<path_to_service_account_file>
46+
service = googleapiclient.discovery.build('ml', 'v1beta1')
47+
name = 'projects/{}/models/{}'.format(project, model)
48+
49+
if version is not None:
50+
name += '/versions/{}'.format(version)
51+
52+
response = service.projects().predict(
53+
name=name,
54+
body={'instances': instances}
55+
).execute()
56+
57+
if 'error' in response:
58+
raise RuntimeError(response['error'])
59+
60+
return response['predictions']
61+
# [END predict_json]
62+
63+
64+
# [START predict_tf_records]
65+
def predict_tf_records(project,
66+
model,
67+
example_bytes_list,
68+
version=None):
69+
"""Send protocol buffer data to a deployed model for prediction.
70+
71+
Args:
72+
project (str): project where the Cloud ML Engine Model is deployed.
73+
model (str): model name.
74+
example_bytes_list ([str]): A list of bytestrings representing
75+
serialized tf.train.Example protocol buffers. The contents of this
76+
protocol buffer will change depending on the signature of your
77+
deployed model.
78+
version: str, version of the model to target.
79+
Returns:
80+
Mapping[str: any]: dictionary of prediction results defined by the
81+
model.
82+
"""
83+
service = googleapiclient.discovery.build('ml', 'v1beta1')
84+
name = 'projects/{}/models/{}'.format(project, model)
85+
86+
if version is not None:
87+
name += '/versions/{}'.format(version)
88+
89+
response = service.projects().predict(
90+
name=name,
91+
body={'instances': [
92+
{'b64': base64.b64encode(example_bytes)}
93+
for example_bytes in example_bytes_list
94+
]}
95+
).execute()
96+
97+
if 'error' in response:
98+
raise RuntimeError(response['error'])
99+
100+
return response['predictions']
101+
# [END predict_tf_records]
102+
103+
104+
# [START census_to_example_bytes]
105+
def census_to_example_bytes(json_instance):
106+
"""Serialize a JSON example to the bytes of a tf.train.Example.
107+
This method is specific to the signature of the Census example.
108+
See: https://cloud.google.com/ml-engine/docs/concepts/prediction-overview
109+
for details.
110+
111+
Args:
112+
json_instance (Mapping[str: Any]): Keys should be the names of Tensors
113+
your deployed model expects to parse using it's tf.FeatureSpec.
114+
Values should be datatypes convertible to Tensors, or (potentially
115+
nested) lists of datatypes convertible to tensors.
116+
Returns:
117+
str: A string as a container for the serialized bytes of
118+
tf.train.Example protocol buffer.
119+
"""
120+
import tensorflow as tf
121+
feature_dict = {}
122+
for key, data in json_instance.iteritems():
123+
if isinstance(data, six.string_types):
124+
feature_dict[key] = tf.train.Feature(
125+
bytes_list=tf.train.BytesList(value=[str(data)]))
126+
elif isinstance(data, float):
127+
feature_dict[key] = tf.train.Feature(
128+
float_list=tf.train.FloatList(value=[data]))
129+
elif isinstance(data, int):
130+
feature_dict[key] = tf.train.Feature(
131+
int64_list=tf.train.Int64List(value=[data]))
132+
return tf.train.Example(
133+
features=tf.train.Features(
134+
feature=feature_dict
135+
)
136+
).SerializeToString()
137+
# [END census_to_example_bytes]
138+
139+
140+
def main(project, model, version=None, force_tfrecord=False):
141+
"""Send user input to the prediction service."""
142+
while True:
143+
try:
144+
user_input = json.loads(raw_input("Valid JSON >>>"))
145+
except KeyboardInterrupt:
146+
return
147+
148+
if not isinstance(user_input, list):
149+
user_input = [user_input]
150+
try:
151+
if force_tfrecord:
152+
example_bytes_list = [
153+
census_to_example_bytes(e)
154+
for e in user_input
155+
]
156+
result = predict_tf_records(
157+
project, model, example_bytes_list, version=version)
158+
else:
159+
result = predict_json(
160+
project, model, user_input, version=version)
161+
except RuntimeError as err:
162+
print(str(err))
163+
else:
164+
print(result)
165+
166+
167+
if __name__ == '__main__':
168+
parser = argparse.ArgumentParser()
169+
parser.add_argument(
170+
'--project',
171+
help='Project in which the model is deployed',
172+
type=str,
173+
required=True
174+
)
175+
parser.add_argument(
176+
'--model',
177+
help='Model name',
178+
type=str,
179+
required=True
180+
)
181+
parser.add_argument(
182+
'--version',
183+
help='Name of the version.',
184+
type=str
185+
)
186+
parser.add_argument(
187+
'--force-tfrecord',
188+
help='Send predictions as TFRecords rather than raw JSON',
189+
action='store_true',
190+
default=False
191+
)
192+
args = parser.parse_args()
193+
main(
194+
args.project,
195+
args.model,
196+
version=args.version,
197+
force_tfrecord=args.force_tfrecord
198+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2017 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for predict.py ."""
16+
17+
import base64
18+
19+
import pytest
20+
21+
import predict
22+
23+
24+
MODEL = 'census'
25+
VERSION = 'v1'
26+
TF_RECORDS_VERSION = 'v1tfrecord'
27+
PROJECT = 'python-docs-samples-tests'
28+
JSON = {
29+
'age': 25,
30+
'workclass': ' Private',
31+
'education': ' 11th',
32+
'education_num': 7,
33+
'marital_status': ' Never-married',
34+
'occupation': ' Machine-op-inspct',
35+
'relationship': ' Own-child',
36+
'race': ' Black',
37+
'gender': ' Male',
38+
'capital_gain': 0,
39+
'capital_loss': 0,
40+
'hours_per_week': 40,
41+
'native_country': ' United-States'
42+
}
43+
EXPECTED_OUTPUT = {
44+
u'probabilities': [0.9942260384559631, 0.005774002522230148],
45+
u'logits': [-5.148599147796631],
46+
u'classes': 0,
47+
u'logistic': [0.005774001590907574]
48+
}
49+
50+
51+
def test_predict_json():
52+
result = predict.predict_json(
53+
PROJECT, MODEL, [JSON, JSON], version=VERSION)
54+
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
55+
56+
57+
def test_predict_json_error():
58+
with pytest.raises(RuntimeError):
59+
predict.predict_json(PROJECT, MODEL, [{"foo": "bar"}], version=VERSION)
60+
61+
62+
@pytest.mark.slow
63+
def test_census_example_to_bytes():
64+
b = predict.census_to_example_bytes(JSON)
65+
assert base64.b64encode(b) is not None
66+
67+
68+
@pytest.mark.slow
69+
@pytest.mark.xfail('Single placeholder inputs broken in service b/35778449')
70+
def test_predict_tfrecords():
71+
b = predict.census_to_example_bytes(JSON)
72+
result = predict.predict_tfrecords(
73+
PROJECT, MODEL, [b, b], version=TF_RECORDS_VERSION)
74+
assert [EXPECTED_OUTPUT, EXPECTED_OUTPUT] == result
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
tensorflow==1.0.0

0 commit comments

Comments
 (0)