Skip to content

Commit 28dec7f

Browse files
guptapriyatensorflower-gardener
authored andcommitted
Add basic serialization support to DistributedVariable (by using the underlying primary variable's serialization). Also, throw an exception when trying to de-serialize as we haven't implemented that yet.
PiperOrigin-RevId: 191022884
1 parent 28b95b3 commit 28dec7f

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

tensorflow/contrib/distribute/python/values.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def shape(self):
216216
def get_shape(self):
217217
return self._primary_var.get_shape()
218218

219+
def to_proto(self, export_scope=None):
220+
return self._primary_var.to_proto(export_scope=export_scope)
221+
219222
@property
220223
def op(self):
221224
# We want cross-tower code that does some var.op.X calls

tensorflow/python/training/distribute.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,3 +1194,20 @@ def merge_fn(dist, vm):
11941194
_default_tower_context = TowerContext(
11951195
_default_distribution_strategy, tower_id=0)
11961196
_default_tower_mode = _DefaultTowerThreadMode()
1197+
1198+
1199+
# ------------------------------------------------------------------------------
1200+
# We haven't yet implemented deserialization for DistributedVariables.
1201+
# So here we catch any attempts to deserialize variables
1202+
# when using distribution strategies.
1203+
# pylint: disable=protected-access
1204+
def _from_proto_fn(v, import_scope=None):
1205+
if has_distribution_strategy():
1206+
raise NotImplementedError(
1207+
"Deserialization of variables is not yet supported when using"
1208+
"distributed strategies.")
1209+
else:
1210+
resource_variable_ops._from_proto_fn(v, import_scope=import_scope)
1211+
1212+
resource_variable_ops._from_proto_fn = _from_proto_fn
1213+
# pylint: enable=protected-access

0 commit comments

Comments
 (0)