Skip to content

Commit a8a95bf

Browse files
Add distribution support for incrementing the global step.
Don't require Dataset as input to eval and predict even when using a DistributionStrategy. PiperOrigin-RevId: 191017191
1 parent ec70d93 commit a8a95bf

File tree

7 files changed

+40
-20
lines changed

7 files changed

+40
-20
lines changed

tensorflow/python/estimator/canned/baseline_test.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@
4242
from tensorflow.python.ops import data_flow_ops
4343
from tensorflow.python.ops import math_ops
4444
from tensorflow.python.ops import parsing_ops
45-
from tensorflow.python.ops import state_ops
4645
from tensorflow.python.ops import variable_scope
4746
from tensorflow.python.ops import variables
4847
from tensorflow.python.platform import gfile
4948
from tensorflow.python.platform import test
5049
from tensorflow.python.summary.writer import writer_cache
5150
from tensorflow.python.training import checkpoint_utils
51+
from tensorflow.python.training import distribute as distribute_lib
5252
from tensorflow.python.training import input as input_lib
5353
from tensorflow.python.training import optimizer
5454
from tensorflow.python.training import queue_runner
@@ -482,15 +482,15 @@ def _minimize(loss, global_step=None, var_list=None):
482482
self.assertEquals(0, loss.shape.ndims)
483483
if expected_loss is None:
484484
if global_step is not None:
485-
return state_ops.assign_add(global_step, 1).op
485+
return distribute_lib.increment_var(global_step)
486486
return control_flow_ops.no_op()
487487
assert_loss = assert_close(
488488
math_ops.to_float(expected_loss, name='expected'),
489489
loss,
490490
name='assert_loss')
491491
with ops.control_dependencies((assert_loss,)):
492492
if global_step is not None:
493-
return state_ops.assign_add(global_step, 1).op
493+
return distribute_lib.increment_var(global_step)
494494
return control_flow_ops.no_op()
495495

496496
mock_optimizer = test.mock.NonCallableMock(
@@ -685,13 +685,13 @@ def _minimize(loss, global_step):
685685
# Verify loss. We can't check the value directly, so we add an assert op.
686686
self.assertEquals(0, loss.shape.ndims)
687687
if expected_loss is None:
688-
return state_ops.assign_add(global_step, 1).op
688+
return distribute_lib.increment_var(global_step)
689689
assert_loss = assert_close(
690690
math_ops.to_float(expected_loss, name='expected'),
691691
loss,
692692
name='assert_loss')
693693
with ops.control_dependencies((assert_loss,)):
694-
return state_ops.assign_add(global_step, 1).op
694+
return distribute_lib.increment_var(global_step)
695695

696696
mock_optimizer = test.mock.NonCallableMock(
697697
spec=optimizer.Optimizer,

tensorflow/python/estimator/canned/boosted_trees.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@
3232
from tensorflow.python.ops import gradients_impl
3333
from tensorflow.python.ops import lookup_ops
3434
from tensorflow.python.ops import math_ops
35-
from tensorflow.python.ops import state_ops
3635
from tensorflow.python.ops import variable_scope
3736
from tensorflow.python.ops.losses import losses
3837
from tensorflow.python.summary import summary
38+
from tensorflow.python.training import distribute as distribute_lib
3939
from tensorflow.python.training import session_run_hook
4040
from tensorflow.python.training import training_util
4141
from tensorflow.python.util.tf_export import tf_export
@@ -425,7 +425,7 @@ def grow_tree_from_stats_summaries(stats_summary_list):
425425
return grow_op
426426

427427
if train_in_memory and is_single_machine:
428-
train_op.append(state_ops.assign_add(global_step, 1))
428+
train_op.append(distribute_lib.increment_var(global_step))
429429
train_op.append(grow_tree_from_stats_summaries(stats_summary_list))
430430
else:
431431
summary_accumulator = data_flow_ops.ConditionalAccumulator(
@@ -445,7 +445,7 @@ def grow_tree_from_accumulated_summaries_fn():
445445
return grow_op
446446

447447
with ops.control_dependencies([apply_grad]):
448-
train_op.append(state_ops.assign_add(global_step, 1))
448+
train_op.append(distribute_lib.increment_var(global_step))
449449
if config.is_chief:
450450
train_op.append(
451451
control_flow_ops.cond(

tensorflow/python/estimator/canned/dnn_linear_combined.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,10 @@
3131
from tensorflow.python.ops import control_flow_ops
3232
from tensorflow.python.ops import nn
3333
from tensorflow.python.ops import partitioned_variables
34-
from tensorflow.python.ops import state_ops
3534
from tensorflow.python.ops import variable_scope
3635
from tensorflow.python.ops.losses import losses
3736
from tensorflow.python.summary import summary
37+
from tensorflow.python.training import distribute as distribute_lib
3838
from tensorflow.python.training import sync_replicas_optimizer
3939
from tensorflow.python.training import training_util
4040
from tensorflow.python.util.tf_export import tf_export
@@ -215,8 +215,7 @@ def _train_op_fn(loss):
215215

216216
train_op = control_flow_ops.group(*train_ops)
217217
with ops.control_dependencies([train_op]):
218-
with ops.colocate_with(global_step):
219-
return state_ops.assign_add(global_step, 1)
218+
return distribute_lib.increment_var(global_step)
220219

221220
return head.create_estimator_spec(
222221
features=features,

tensorflow/python/estimator/canned/dnn_testing_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,13 @@
4444
from tensorflow.python.ops import math_ops
4545
from tensorflow.python.ops import nn
4646
from tensorflow.python.ops import partitioned_variables
47-
from tensorflow.python.ops import state_ops
4847
from tensorflow.python.ops import variable_scope
4948
from tensorflow.python.ops import variables as variables_lib
5049
from tensorflow.python.platform import test
5150
from tensorflow.python.summary import summary as summary_lib
5251
from tensorflow.python.summary.writer import writer_cache
5352
from tensorflow.python.training import checkpoint_utils
53+
from tensorflow.python.training import distribute as distribute_lib
5454
from tensorflow.python.training import gradient_descent
5555
from tensorflow.python.training import monitored_session
5656
from tensorflow.python.training import optimizer as optimizer_lib
@@ -196,15 +196,15 @@ def _minimize(loss, global_step=None, var_list=None):
196196
testcase.assertEquals(0, loss.shape.ndims)
197197
if expected_loss is None:
198198
if global_step is not None:
199-
return state_ops.assign_add(global_step, 1).op
199+
return distribute_lib.increment_var(global_step)
200200
return control_flow_ops.no_op()
201201
assert_loss = assert_close(
202202
math_ops.to_float(expected_loss, name='expected'),
203203
loss,
204204
name='assert_loss')
205205
with ops.control_dependencies((assert_loss,)):
206206
if global_step is not None:
207-
return state_ops.assign_add(global_step, 1).op
207+
return distribute_lib.increment_var(global_step)
208208
return control_flow_ops.no_op()
209209

210210
optimizer_mock = test.mock.NonCallableMagicMock(

tensorflow/python/estimator/canned/linear_testing_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@
4747
from tensorflow.python.ops import math_ops
4848
from tensorflow.python.ops import parsing_ops
4949
from tensorflow.python.ops import partitioned_variables
50-
from tensorflow.python.ops import state_ops
5150
from tensorflow.python.ops import variable_scope
5251
from tensorflow.python.ops import variables as variables_lib
5352
from tensorflow.python.platform import gfile
5453
from tensorflow.python.platform import test
5554
from tensorflow.python.summary.writer import writer_cache
5655
from tensorflow.python.training import checkpoint_utils
56+
from tensorflow.python.training import distribute as distribute_lib
5757
from tensorflow.python.training import gradient_descent
5858
from tensorflow.python.training import input as input_lib
5959
from tensorflow.python.training import optimizer as optimizer_lib
@@ -682,15 +682,15 @@ def _minimize(loss, global_step=None, var_list=None):
682682
self.assertEquals(0, loss.shape.ndims)
683683
if expected_loss is None:
684684
if global_step is not None:
685-
return state_ops.assign_add(global_step, 1).op
685+
return distribute_lib.increment_var(global_step)
686686
return control_flow_ops.no_op()
687687
assert_loss = assert_close(
688688
math_ops.to_float(expected_loss, name='expected'),
689689
loss,
690690
name='assert_loss')
691691
with ops.control_dependencies((assert_loss,)):
692692
if global_step is not None:
693-
return state_ops.assign_add(global_step, 1).op
693+
return distribute_lib.increment_var(global_step)
694694
return control_flow_ops.no_op()
695695

696696
mock_optimizer = test.mock.NonCallableMock(
@@ -905,13 +905,13 @@ def _minimize(loss, global_step):
905905
# Verify loss. We can't check the value directly, so we add an assert op.
906906
self.assertEquals(0, loss.shape.ndims)
907907
if expected_loss is None:
908-
return state_ops.assign_add(global_step, 1).op
908+
return distribute_lib.increment_var(global_step)
909909
assert_loss = assert_close(
910910
math_ops.to_float(expected_loss, name='expected'),
911911
loss,
912912
name='assert_loss')
913913
with ops.control_dependencies((assert_loss,)):
914-
return state_ops.assign_add(global_step, 1).op
914+
return distribute_lib.increment_var(global_step)
915915

916916
mock_optimizer = test.mock.NonCallableMock(
917917
spec=optimizer_lib.Optimizer,

tensorflow/python/estimator/estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,8 @@ def _get_features_and_labels_from_input_fn(self, input_fn, mode):
693693
# using any input is alright in that case. There is also a
694694
# has_dataset_or_queue_runner function that we may want to extend and use.
695695
if (self._distribution is not None and
696-
not isinstance(result, dataset_ops.Dataset)):
696+
not isinstance(result, dataset_ops.Dataset) and
697+
mode == model_fn_lib.ModeKeys.TRAIN):
697698
raise ValueError('input_fn() must return a tf.data.Dataset when using a '
698699
'DistributionStrategy.')
699700
input_hooks = []

tensorflow/python/training/distribute.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from tensorflow.python.framework import ops
2424
from tensorflow.python.ops import array_ops
2525
from tensorflow.python.ops import control_flow_ops
26+
from tensorflow.python.ops import resource_variable_ops
27+
from tensorflow.python.ops import state_ops
2628
from tensorflow.python.ops import variable_scope
2729
from tensorflow.python.ops.losses import losses_impl
2830
from tensorflow.python.training import device_util
@@ -1166,6 +1168,24 @@ def _worker_device_index(self):
11661168
raise RuntimeError("worker_device_index() method unsupported by "
11671169
"_DefaultDistributionStrategy.")
11681170

1171+
# ------------------------------------------------------------------------------
1172+
# Common operations
1173+
1174+
1175+
def increment_var(v, amount=1):
1176+
"""`v += amount`, distributed-aware version."""
1177+
def update(vu):
1178+
if isinstance(vu, resource_variable_ops.ResourceVariable):
1179+
return vu.assign_add(amount, read_value=False)
1180+
else:
1181+
return state_ops.assign_add(vu, amount)
1182+
1183+
def merge_fn(dist, vm):
1184+
return dist.group(dist.update(vm, update))
1185+
1186+
tower_context = get_tower_context()
1187+
return tower_context.merge_call(merge_fn, v)
1188+
11691189

11701190
# ------------------------------------------------------------------------------
11711191
# Singletons

0 commit comments

Comments (0)