Skip to content

Commit b0a15f2

Browse files
Make the return value of read_var consistently a tensor instead of
sometimes a variable. PiperOrigin-RevId: 200231463
1 parent 3b4f416 commit b0a15f2

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

tensorflow/contrib/distribute/python/mirrored_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ def read_var(self, tower_local_var):
349349
if isinstance(tower_local_var, values.TowerLocalVariable):
350350
return math_ops.add_n(self.unwrap(tower_local_var))
351351
assert isinstance(tower_local_var, values.Mirrored)
352-
return tower_local_var.get()
352+
return array_ops.identity(tower_local_var.get())
353353

354354
def _fetch(self, val, destination, fn):
355355
"""Return a copy of `val` or `fn(val)` on `destination`."""

tensorflow/contrib/distribute/python/one_device_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
104104

105105
def read_var(self, tower_local_var):
106106
"""Read the aggregate value of a tower-local variable."""
107-
return tower_local_var
107+
return array_ops.identity(tower_local_var)
108108

109109
def _fetch(self, val, destination, fn):
110110
"""Return a copy of `val` or `fn(val)` on `destination`."""

tensorflow/python/training/distribute.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ def read_var(self, v):
652652
"""Reads the value of a variable.
653653
654654
Returns the aggregate value of a tower-local variable, or the
655-
(possibly read-only) value of any other variable.
655+
(read-only) value of any other variable.
656656
657657
Args:
658658
v: A variable allocated within the scope of this `DistributionStrategy`.
@@ -1217,7 +1217,7 @@ def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
12171217
return fn(*args, **kwargs)
12181218

12191219
def read_var(self, tower_local_var):
1220-
return tower_local_var
1220+
return array_ops.identity(tower_local_var)
12211221

12221222
def _fetch(self, var, destination, fn):
12231223
with ops.colocate_with(var):

0 commit comments

Comments
 (0)