Skip to content

Commit fb301a9

Browse files
committed
Fixed bug in tf.train.ClusterSpec constructor.
Creating a `tf.train.ClusterSpec` from another ClusterSpec was broken, which in turn broke creating a `tf.train.Server` from a ClusterSpec. Fixes tensorflow#1961. Change: 119954117
1 parent 35cd6a3 commit fb301a9

File tree

2 files changed

+24
-1
lines changed

2 files changed

+24
-1
lines changed

tensorflow/python/training/server_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,7 @@ def __init__(self, cluster):
238238
elif isinstance(cluster, ClusterSpec):
239239
self._cluster_def = tensorflow_server_pb2.ClusterDef()
240240
self._cluster_def.MergeFrom(cluster.as_cluster_def())
241+
self._cluster_spec = {}
241242
for job_def in self._cluster_def.job:
242243
self._cluster_spec[job_def.name] = [t for t in job_def.tasks.values()]
243244
else:
@@ -306,4 +307,3 @@ def _make_cluster_def(self):
306307
raise TypeError(
307308
"Task address %r must be bytes or unicode" % task_address)
308309
job_def.tasks[i] = task_address
309-

tensorflow/python/training/server_lib_test.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,29 @@ def testTwoJobs(self):
146146
cluster_spec = tf.train.ClusterSpec(cluster_def)
147147
self.assertProtoEquals(cluster_def, cluster_spec.as_cluster_def())
148148

149+
def testClusterSpec(self):
150+
cluster_spec = tf.train.ClusterSpec(
151+
{"ps": ["ps0:2222", "ps1:2222"],
152+
"worker": ["worker0:2222", "worker1:2222", "worker2:2222"]})
153+
154+
expected_proto = """
155+
job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
156+
tasks { key: 1 value: 'ps1:2222' } }
157+
job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
158+
tasks { key: 1 value: 'worker1:2222' }
159+
tasks { key: 2 value: 'worker2:2222' } }
160+
"""
161+
162+
self.assertProtoEquals(expected_proto, cluster_spec.as_cluster_def())
163+
self.assertProtoEquals(
164+
expected_proto, tf.train.ClusterSpec(cluster_spec).as_cluster_def())
165+
self.assertProtoEquals(
166+
expected_proto,
167+
tf.train.ClusterSpec(cluster_spec.as_cluster_def()).as_cluster_def())
168+
self.assertProtoEquals(
169+
expected_proto,
170+
tf.train.ClusterSpec(cluster_spec.as_dict()).as_cluster_def())
171+
149172

150173
if __name__ == "__main__":
151174
tf.test.main()

0 commit comments

Comments
 (0)