Skip to content

Commit dc7821c

Browse files
Apply import_scope to asset and variable tensors during tf.saved_model.loader.load
This change explicitly declares import_scope as a kwarg for tf.saved_model.loader.load. Previously, tf.saved_model.loader.load implicitly accepted import_scope and passed it through to import_meta_graph through **saver_kwargs. PiperOrigin-RevId: 200249417
1 parent ba9422a commit dc7821c

File tree

5 files changed

+111
-8
lines changed

5 files changed

+111
-8
lines changed

tensorflow/python/saved_model/loader_impl.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,12 +79,14 @@ def _parse_saved_model(export_dir):
7979
constants.SAVED_MODEL_FILENAME_PB))
8080

8181

82-
def _get_asset_tensors(export_dir, meta_graph_def_to_load):
82+
def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
8383
"""Gets the asset tensors, if defined in the meta graph def to load.
8484
8585
Args:
8686
export_dir: Directory where the SavedModel is located.
8787
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
88+
import_scope: Optional `string` -- if specified, prepend this followed by
89+
'/' to all returned asset tensor names.
8890
8991
Returns:
9092
A dictionary of asset tensors, keyed by the name of the asset tensor. The
@@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
104106
for asset_any_proto in assets_any_proto:
105107
asset_proto = meta_graph_pb2.AssetFileDef()
106108
asset_any_proto.Unpack(asset_proto)
107-
asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
109+
tensor_name = asset_proto.tensor_info.name
110+
if import_scope:
111+
tensor_name = "%s/%s" % (import_scope, tensor_name)
112+
asset_tensor_dict[tensor_name] = os.path.join(
108113
compat.as_bytes(assets_directory),
109114
compat.as_bytes(asset_proto.filename))
110115
return asset_tensor_dict
@@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir):
179184

180185

181186
@tf_export("saved_model.loader.load")
182-
def load(sess, tags, export_dir, **saver_kwargs):
187+
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
183188
"""Loads the model from a SavedModel as specified by tags.
184189
185190
Args:
@@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
189194
SavedModel `save()` API.
190195
export_dir: Directory in which the SavedModel protocol buffer and variables
191196
to be loaded are located.
197+
import_scope: Optional `string` -- if specified, prepend this string
198+
followed by '/' to all loaded tensor names. This scope is applied to
199+
tensor instances loaded into the passed session, but it is *not* written
200+
through to the static `MetaGraphDef` protocol buffer that is returned.
192201
**saver_kwargs: Optional keyword arguments passed through to Saver.
193202
194203
Returns:
@@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
216225
)
217226

218227
# Build a saver by importing the meta graph def to load.
219-
saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
228+
saver = tf_saver.import_meta_graph(
229+
meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
220230

221231
if saver:
222232
# Build the checkpoint path where the variables are located.
@@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
232242
"checkpoints were restored.")
233243

234244
# Get asset tensors, if any.
235-
asset_tensors_dictionary = _get_asset_tensors(export_dir,
236-
meta_graph_def_to_load)
245+
asset_tensors_dictionary = _get_asset_tensors(
246+
export_dir, meta_graph_def_to_load, import_scope=import_scope)
237247

238248
main_op_tensor = (
239249
_get_main_op_tensor(meta_graph_def_to_load) or

tensorflow/python/saved_model/saved_model_test.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,59 @@ def _validate_custom_saver(tag_name, saver_name):
11971197
_validate_custom_saver("tag_1", "save_1/restore_all")
11981198
_validate_custom_saver("tag_2", "save_2/restore_all")
11991199

1200+
def testImportScope(self):
1201+
export_dir = self._get_export_dir("test_scoped_assets")
1202+
builder = saved_model_builder.SavedModelBuilder(export_dir)
1203+
1204+
# Build a SavedModel with a variable, an asset, and a constant tensor.
1205+
with self.test_session(graph=ops.Graph()) as sess:
1206+
self._init_and_validate_variable(sess, "v", 42)
1207+
asset_collection = self._build_asset_collection("foo.txt", "content_foo",
1208+
"asset_file_tensor")
1209+
constant_op.constant("constant value", name="constant_tensor_name")
1210+
builder.add_meta_graph_and_variables(
1211+
sess, ["tag_name"], assets_collection=asset_collection)
1212+
1213+
# Save the asset file path for later comparison.
1214+
asset_file_path = asset_collection[0].eval()
1215+
1216+
# Save the SavedModel to disk.
1217+
builder.save()
1218+
1219+
with self.test_session(graph=ops.Graph()) as sess:
1220+
# Restore the SavedModel under an import_scope in a new graph/session.
1221+
graph_proto = loader.load(
1222+
sess, ["tag_name"], export_dir, import_scope="scope_name")
1223+
1224+
# The loaded variable tensor should be scoped, but its contents should be
1225+
# unchanged.
1226+
self.assertEqual(
1227+
"scope_name/v:0",
1228+
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
1229+
self.assertEqual(
1230+
42,
1231+
ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
1232+
1233+
# The loaded asset tensor should be scoped, but the asset file path and
1234+
# contents should be unchanged.
1235+
asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
1236+
self.assertEqual(1, len(asset_collection))
1237+
self.assertEqual(asset_file_path, asset_collection[0].eval())
1238+
self.assertEqual("scope_name/asset_file_tensor:0",
1239+
asset_collection[0].name)
1240+
# The static asset data inside graph_proto.collection_def should not be
1241+
# scoped.
1242+
self._validate_asset_collection(export_dir, graph_proto.collection_def,
1243+
"foo.txt", "content_foo",
1244+
"asset_file_tensor:0")
1245+
1246+
# The constant tensor should be scoped, but its contents should be
1247+
# unchanged.
1248+
self.assertEqual(
1249+
compat.as_bytes("constant value"),
1250+
ops.get_default_graph().get_tensor_by_name(
1251+
"scope_name/constant_tensor_name:0").eval())
1252+
12001253
def testClearDevices(self):
12011254
export_dir = self._get_export_dir("test_clear_devices")
12021255
builder = saved_model_builder.SavedModelBuilder(export_dir)

tensorflow/python/training/saver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1970,7 +1970,7 @@ def import_meta_graph(meta_graph_or_file, clear_devices=False,
19701970

19711971
return Saver(saver_def=meta_graph_def.saver_def, name=scope)
19721972
else:
1973-
if variables._all_saveable_objects(): # pylint: disable=protected-access
1973+
if variables._all_saveable_objects(scope=import_scope): # pylint: disable=protected-access
19741974
# Return the default saver instance for all graph variables.
19751975
return Saver()
19761976
else:

tensorflow/python/training/saver_test.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2339,6 +2339,46 @@ def testImportIntoNamescope(self):
23392339
10, size=[1, 10])
23402340
})
23412341

2342+
def testImportIntoNamescopeWithoutVariables(self):
2343+
# Save a simple graph that contains no variables into a checkpoint.
2344+
test_dir = self._get_test_dir("no_vars_graph")
2345+
filename = os.path.join(test_dir, "ckpt")
2346+
graph_1 = ops_lib.Graph()
2347+
with session.Session(graph=graph_1) as sess:
2348+
constant_op.constant([1, 2, 3], name="x")
2349+
constant_op.constant([1, 2, 3], name="y")
2350+
saver = saver_module.Saver(allow_empty=True)
2351+
saver.save(sess, filename)
2352+
2353+
# Create a fresh graph.
2354+
graph_2 = ops_lib.Graph()
2355+
with session.Session(graph=graph_2) as sess:
2356+
# Restore the above checkpoint under scope "subgraph_1".
2357+
new_saver_1 = saver_module.import_meta_graph(
2358+
filename + ".meta", graph=graph_2, import_scope="subgraph_1")
2359+
# There are no variables to restore, so import_meta_graph should not
2360+
# return a Saver.
2361+
self.assertIsNone(new_saver_1)
2362+
2363+
# Create a variable in graph_2 under scope "my_scope".
2364+
variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
2365+
sess.run(variables.global_variables_initializer())
2366+
# Restore the checkpoint into a different scope "subgraph_2".
2367+
new_saver_2 = saver_module.import_meta_graph(
2368+
filename + ".meta", graph=graph_2, import_scope="subgraph_2")
2369+
# Because the variable does not live in scope "subgraph_2",
2370+
# import_meta_graph should not attempt to restore the variable. So,
2371+
# import_meta_graph still won't return a Saver instance.
2372+
self.assertIsNone(new_saver_2)
2373+
2374+
# However, if we restore the checkpoint under scope "my_scope",
2375+
# import_meta_graph will detect the variable and return a Saver for
2376+
# restoring it. This should happen even when the variable does not
2377+
# originate from graph_1.
2378+
new_saver_3 = saver_module.import_meta_graph(
2379+
filename + ".meta", graph=graph_2, import_scope="my_scope")
2380+
self.assertIsInstance(new_saver_3, saver_module.Saver)
2381+
23422382
def testImportIntoImplicitNamescope(self):
23432383
# Test that we can import a meta graph into an implicit namescope.
23442384
test_dir = self._get_test_dir("import_into_namescope")

tensorflow/tools/api/golden/tensorflow.saved_model.loader.pbtxt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ path: "tensorflow.saved_model.loader"
22
tf_module {
33
member_method {
44
name: "load"
5-
argspec: "args=[\'sess\', \'tags\', \'export_dir\'], varargs=None, keywords=saver_kwargs, defaults=None"
5+
argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
66
}
77
member_method {
88
name: "maybe_saved_model_directory"

0 commit comments

Comments
 (0)