@@ -92,6 +92,7 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
92
92
if self .patience < 1 :
93
93
break
94
94
dist .barrier ()
95
+ args .device = args .local_rank
95
96
parser = self .load (** args )
96
97
loss , metric = parser ._evaluate (test .loader )
97
98
# only allow the master device to save models
@@ -163,7 +164,7 @@ def build(cls, path, **kwargs):
163
164
raise NotImplementedError
164
165
165
166
@classmethod
166
- def load (cls , path , reload = False , src = 'github' , checkpoint = False , ** kwargs ):
167
+ def load (cls , path , reload = False , src = 'github' , checkpoint = False , device = None , ** kwargs ):
167
168
r"""
168
169
Loads a parser with data fields and pretrained model parameters.
169
170
@@ -181,6 +182,9 @@ def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
181
182
Default: ``'github'``.
182
183
checkpoint (bool):
183
184
If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
185
+ device (:class:`torch.device`):
186
+ The desired device of the model parameters.
187
+ If ``None``, uses the default GPU device (if available), otherwise uses the CPU device. Default: ``None``.
184
188
kwargs (dict):
185
189
A dict holding unconsumed arguments for updating training configs and initializing the model.
186
190
@@ -191,17 +195,20 @@ def load(cls, path, reload=False, src='github', checkpoint=False, **kwargs):
191
195
"""
192
196
193
197
args = Config (** locals ())
194
- args .device = 'cuda' if torch .cuda .is_available () else 'cpu'
195
- state = torch .load (path if os .path .exists (path ) else download (supar .MODEL [src ].get (path , path ), reload = reload ))
198
+ if device is None :
199
+ device = 'cuda' if torch .cuda .is_available () else 'cpu'
200
+ if not os .path .exists (path ):
201
+ path = download (supar .MODEL [src ].get (path , path ), reload = reload )
202
+ state = torch .load (path )
196
203
cls = supar .PARSER [state ['name' ]] if cls .NAME is None else cls
197
204
args = state ['args' ].update (args )
198
205
model = cls .MODEL (** args )
199
206
model .load_pretrained (state ['pretrained' ])
200
207
model .load_state_dict (state ['state_dict' ], False )
201
- model .to (args . device )
208
+ model .to (device )
202
209
transform = state ['transform' ]
203
210
parser = cls (args , model , transform )
204
- parser .checkpoint_state_dict = state ['checkpoint_state_dict' ] if args . checkpoint else None
211
+ parser .checkpoint_state_dict = state ['checkpoint_state_dict' ] if checkpoint else None
205
212
return parser
206
213
207
214
def save (self , path ):
0 commit comments