@@ -71,7 +71,10 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
71
71
start = datetime .now ()
72
72
73
73
logger .info (f"Epoch { epoch } / { args .epochs } :" )
74
- with self .model .join ():
74
+ if dist .is_initialized ():
75
+ with self .model .join ():
76
+ self ._train (train .loader )
77
+ else :
75
78
self ._train (train .loader )
76
79
loss , dev_metric = self ._evaluate (dev .loader )
77
80
logger .info (f"{ 'dev:' :5} loss: { loss :.4f} - { dev_metric } " )
@@ -92,7 +95,8 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
92
95
logger .info (f"{ t } s elapsed\n " )
93
96
if self .patience < 1 :
94
97
break
95
- dist .barrier ()
98
+ if dist .is_initialized ():
99
+ dist .barrier ()
96
100
args .device = args .local_rank
97
101
parser = self .load (** args )
98
102
loss , metric = parser ._evaluate (test .loader )
@@ -165,7 +169,7 @@ def build(cls, path, **kwargs):
165
169
raise NotImplementedError
166
170
167
171
@classmethod
168
- def load (cls , path , reload = False , src = 'github' , checkpoint = False , device = None , ** kwargs ):
172
+ def load (cls , path , reload = False , src = 'github' , checkpoint = False , ** kwargs ):
169
173
r"""
170
174
Loads a parser with data fields and pretrained model parameters.
171
175
@@ -183,9 +187,6 @@ def load(cls, path, reload=False, src='github', checkpoint=False, device=None, *
183
187
Default: ``'github'``.
184
188
checkpoint (bool):
185
189
If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
186
- device (:class:`torch.device`):
187
- The desired device of the model parameters.
188
- If ``None``, uses the default GPU device (if available), otherwise uses the CPU device. Default: ``None``.
189
190
kwargs (dict):
190
191
A dict holding unconsumed arguments for updating training configs and initializing the model.
191
192
@@ -196,11 +197,10 @@ def load(cls, path, reload=False, src='github', checkpoint=False, device=None, *
196
197
"""
197
198
198
199
args = Config (** locals ())
199
- if args .device is None :
200
- args .device = 'cuda' if torch .cuda .is_available () else 'cpu'
200
+ args .device = 'cuda' if torch .cuda .is_available () else 'cpu'
201
201
if not os .path .exists (path ):
202
202
path = download (supar .MODEL [src ].get (path , path ), reload = reload )
203
- state = torch .load (path )
203
+ state = torch .load (path , map_location = 'cpu' )
204
204
cls = supar .PARSER [state ['name' ]] if cls .NAME is None else cls
205
205
args = state ['args' ].update (args )
206
206
model = cls .MODEL (** args )
0 commit comments