File tree Expand file tree Collapse file tree 2 files changed +24
-1
lines changed
tensorflow/python/training Expand file tree Collapse file tree 2 files changed +24
-1
lines changed Original file line number Diff line number Diff line change @@ -238,6 +238,7 @@ def __init__(self, cluster):
238
238
elif isinstance (cluster , ClusterSpec ):
239
239
self ._cluster_def = tensorflow_server_pb2 .ClusterDef ()
240
240
self ._cluster_def .MergeFrom (cluster .as_cluster_def ())
241
+ self ._cluster_spec = {}
241
242
for job_def in self ._cluster_def .job :
242
243
self ._cluster_spec [job_def .name ] = [t for t in job_def .tasks .values ()]
243
244
else :
@@ -306,4 +307,3 @@ def _make_cluster_def(self):
306
307
raise TypeError (
307
308
"Task address %r must be bytes or unicode" % task_address )
308
309
job_def .tasks [i ] = task_address
309
-
Original file line number Diff line number Diff line change @@ -146,6 +146,29 @@ def testTwoJobs(self):
146
146
cluster_spec = tf .train .ClusterSpec (cluster_def )
147
147
self .assertProtoEquals (cluster_def , cluster_spec .as_cluster_def ())
148
148
149
+ def testClusterSpec (self ):
150
+ cluster_spec = tf .train .ClusterSpec (
151
+ {"ps" : ["ps0:2222" , "ps1:2222" ],
152
+ "worker" : ["worker0:2222" , "worker1:2222" , "worker2:2222" ]})
153
+
154
+ expected_proto = """
155
+ job { name: 'ps' tasks { key: 0 value: 'ps0:2222' }
156
+ tasks { key: 1 value: 'ps1:2222' } }
157
+ job { name: 'worker' tasks { key: 0 value: 'worker0:2222' }
158
+ tasks { key: 1 value: 'worker1:2222' }
159
+ tasks { key: 2 value: 'worker2:2222' } }
160
+ """
161
+
162
+ self .assertProtoEquals (expected_proto , cluster_spec .as_cluster_def ())
163
+ self .assertProtoEquals (
164
+ expected_proto , tf .train .ClusterSpec (cluster_spec ).as_cluster_def ())
165
+ self .assertProtoEquals (
166
+ expected_proto ,
167
+ tf .train .ClusterSpec (cluster_spec .as_cluster_def ()).as_cluster_def ())
168
+ self .assertProtoEquals (
169
+ expected_proto ,
170
+ tf .train .ClusterSpec (cluster_spec .as_dict ()).as_cluster_def ())
171
+
149
172
150
173
if __name__ == "__main__" :
151
174
tf .test .main ()
You can’t perform that action at this time.
0 commit comments