
Commit db77881

Add export savedmodel to wide_deep (tensorflow#4041)
1 parent be7da42 commit db77881

File tree

5 files changed, +67 -26 lines changed

official/mnist/mnist.py (-1)

```diff
@@ -246,7 +246,6 @@ def __init__(self):
     super(MNISTArgParser, self).__init__(parents=[
         parsers.BaseParser(),
         parsers.ImageModelParser(),
-        parsers.ExportParser(),
     ])
 
     self.set_defaults(
```

official/resnet/resnet_run_loop.py (-1)

```diff
@@ -465,7 +465,6 @@ def __init__(self, resnet_size_choices=None):
         parsers.BaseParser(),
         parsers.PerformanceParser(),
         parsers.ImageModelParser(),
-        parsers.ExportParser(),
         parsers.BenchmarkParser(),
     ])
 
```

official/utils/arg_parsers/parsers.py (+11, -24)

```diff
@@ -104,12 +104,13 @@ class BaseParser(argparse.ArgumentParser):
     batch_size: Create a flag to specify the batch size.
     multi_gpu: Create a flag to allow the use of all available GPUs.
     hooks: Create a flag to specify hooks for logging.
+    export_dir: Create a flag to specify where a SavedModel should be exported.
   """
 
   def __init__(self, add_help=False, data_dir=True, model_dir=True,
                train_epochs=True, epochs_between_evals=True,
                stop_threshold=True, batch_size=True, multi_gpu=True,
-               hooks=True):
+               hooks=True, export_dir=True):
     super(BaseParser, self).__init__(add_help=add_help)
 
     if data_dir:
@@ -176,6 +177,15 @@ def __init__(self, add_help=False, data_dir=True, model_dir=True,
           metavar="<HK>"
       )
 
+    if export_dir:
+      self.add_argument(
+          "--export_dir", "-ed",
+          help="[default: %(default)s] If set, a SavedModel serialization of "
+               "the model will be exported to this directory at the end of "
+               "training. See the README for more details and relevant links.",
+          metavar="<ED>"
+      )
+
 
 class PerformanceParser(argparse.ArgumentParser):
   """Default parser for specifying performance tuning arguments.
@@ -292,29 +302,6 @@ def __init__(self, add_help=False, data_format=True):
       )
 
 
-class ExportParser(argparse.ArgumentParser):
-  """Parsing options for exporting saved models or other graph defs.
-
-  This is a separate parser for now, but should be made part of BaseParser
-  once all models are brought up to speed.
-
-  Args:
-    add_help: Create the "--help" flag. False if class instance is a parent.
-    export_dir: Create a flag to specify where a SavedModel should be exported.
-  """
-
-  def __init__(self, add_help=False, export_dir=True):
-    super(ExportParser, self).__init__(add_help=add_help)
-    if export_dir:
-      self.add_argument(
-          "--export_dir", "-ed",
-          help="[default: %(default)s] If set, a SavedModel serialization of "
-               "the model will be exported to this directory at the end of "
-               "training. See the README for more details and relevant links.",
-          metavar="<ED>"
-      )
-
-
 class BenchmarkParser(argparse.ArgumentParser):
   """Default parser for benchmark logging.
 
```
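
With `--export_dir` folded into `BaseParser`, any model that lists `BaseParser` as an argparse parent picks up the flag without a separate `ExportParser`. Below is a minimal sketch, not part of this commit, of how such a parser might consume the flag; `ExampleArgParser` and the paths are hypothetical, for illustration only.

```python
# Hypothetical sketch: a model-side parser that inherits the shared flags,
# including the --export_dir / -ed flag that this commit moves into BaseParser.
import argparse

from official.utils.arg_parsers import parsers


class ExampleArgParser(argparse.ArgumentParser):
  """Hypothetical parser; real models (e.g. WideDeepArgParser) look similar."""

  def __init__(self):
    super(ExampleArgParser, self).__init__(parents=[
        parsers.BaseParser(),  # now supplies --export_dir alongside the rest
    ])
    self.set_defaults(model_dir='/tmp/example_model')  # illustrative default


if __name__ == '__main__':
  flags = ExampleArgParser().parse_args(['--export_dir', '/tmp/example_export'])
  # BaseParser declares --export_dir without a default, so it stays None unless
  # the user sets it; models can key exporting off that, as wide_deep.py does.
  print(flags.export_dir)
```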

official/wide_deep/README.md (+31)

````diff
@@ -47,6 +47,37 @@ Run TensorBoard to inspect the details about the graph and training progression.
 tensorboard --logdir=/tmp/census_model
 ```
 
+## Inference with SavedModel
+You can export the model into Tensorflow [SavedModel](https://www.tensorflow.org/programmers_guide/saved_model) format by using the argument `--export_dir`:
+
+```
+python wide_deep.py --export_dir /tmp/wide_deep_saved_model
+```
+
+After the model finishes training, use [`saved_model_cli`](https://www.tensorflow.org/programmers_guide/saved_model#cli_to_inspect_and_execute_savedmodel) to inspect and execute the SavedModel.
+
+Try the following commands to inspect the SavedModel:
+
+**Replace `${TIMESTAMP}` with the folder produced (e.g. 1524249124)**
+```
+# List possible tag_sets. Only one metagraph is saved, so there will be one option.
+saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/
+
+# Show SignatureDefs for tag_set=serve. SignatureDefs define the outputs to show.
+saved_model_cli show --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
+    --tag_set serve --all
+```
+
+### Inference
+Let's use the model to predict the income group of two examples:
+```
+saved_model_cli run --dir /tmp/wide_deep_saved_model/${TIMESTAMP}/ \
+    --tag_set serve --signature_def="predict" \
+    --input_examples='examples=[{"age":[46.], "education_num":[10.], "capital_gain":[7688.], "capital_loss":[0.], "hours_per_week":[38.]}, {"age":[24.], "education_num":[13.], "capital_gain":[0.], "capital_loss":[0.], "hours_per_week":[50.]}]'
+```
+
+This will print out the predicted classes and class probabilities. Class 0 is the <=50k group and 1 is the >50k group.
+
 ## Additional Links
 
 If you are interested in distributed training, take a look at [Distributed TensorFlow](https://www.tensorflow.org/deploy/distributed).
````
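
The README above drives inference through `saved_model_cli`. For readers who prefer to stay in Python, here is a minimal sketch, not part of the commit, of the same two predictions. It assumes TensorFlow 1.x with `tf.contrib.predictor` available, and that the export base path and timestamp match what `--export_dir` actually produced; only the numeric census features are populated, mirroring the CLI example.

```python
# Hypothetical sketch: programmatic inference against the exported SavedModel,
# equivalent to the saved_model_cli run command in the README diff above.
import tensorflow as tf
from tensorflow.contrib import predictor  # TF 1.x only


def census_example(age, education_num, capital_gain, capital_loss,
                   hours_per_week):
  """Serialize one row as a tf.train.Example, as the parsing receiver expects."""
  def _float(value):
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
  feature = {
      'age': _float(age),
      'education_num': _float(education_num),
      'capital_gain': _float(capital_gain),
      'capital_loss': _float(capital_loss),
      'hours_per_week': _float(hours_per_week),
  }
  return tf.train.Example(
      features=tf.train.Features(feature=feature)).SerializeToString()


# Replace the timestamp with the directory the export actually produced.
predict_fn = predictor.from_saved_model(
    '/tmp/wide_deep_saved_model/1524249124', signature_def_key='predict')

serialized = [
    census_example(46., 10., 7688., 0., 38.),
    census_example(24., 13., 0., 0., 50.),
]
# The parsing serving input receiver exposes its input under the 'examples'
# key, the same key saved_model_cli fills via --input_examples.
outputs = predict_fn({'examples': serialized})
print(outputs['classes'])        # predicted income class per example
print(outputs['probabilities'])  # class probabilities per example
```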

official/wide_deep/wide_deep.py (+25)

```diff
@@ -175,6 +175,27 @@ def parse_csv(value):
   return dataset
 
 
+def export_model(model, model_type, export_dir):
+  """Export to SavedModel format.
+
+  Args:
+    model: Estimator object
+    model_type: string indicating model type. "wide", "deep" or "wide_deep"
+    export_dir: directory to export the model.
+  """
+  wide_columns, deep_columns = build_model_columns()
+  if model_type == 'wide':
+    columns = wide_columns
+  elif model_type == 'deep':
+    columns = deep_columns
+  else:
+    columns = wide_columns + deep_columns
+  feature_spec = tf.feature_column.make_parse_example_spec(columns)
+  example_input_fn = (
+      tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec))
+  model.export_savedmodel(export_dir, example_input_fn)
+
+
 def main(argv):
   parser = WideDeepArgParser()
   flags = parser.parse_args(args=argv[1:])
@@ -216,6 +237,10 @@ def eval_input_fn():
         flags.stop_threshold, results['accuracy']):
       break
 
+  # Export the model
+  if flags.export_dir is not None:
+    export_model(model, flags.model_type, flags.export_dir)
+
 
 class WideDeepArgParser(argparse.ArgumentParser):
   """Argument parser for running the wide deep model."""
```
