@@ -62,11 +62,11 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
62
62
63
63
self .epoch , self .best_e , self .patience , self .best_metric , self .elapsed = 1 , 1 , patience , Metric (), timedelta ()
64
64
if self .args .checkpoint :
65
- self .optimizer .load_state_dict (self .args .pop ('optimizer_state_dict' ))
66
- self .scheduler .load_state_dict (self .args .pop ('scheduler_state_dict' ))
67
- for k , v in args .pop ('state_args' ).items ():
65
+ self .optimizer .load_state_dict (self .checkpoint_state_dict .pop ('optimizer_state_dict' ))
66
+ self .scheduler .load_state_dict (self .checkpoint_state_dict .pop ('scheduler_state_dict' ))
67
+ set_rng_state (self .checkpoint_state_dict .pop ('rng_state' ))
68
+ for k , v in self .checkpoint_state_dict .items ():
68
69
setattr (self , k , v )
69
- set_rng_state (args .pop ('rng_state' ))
70
70
train .loader .batch_sampler .epoch = self .epoch
71
71
72
72
for epoch in range (self .epoch , args .epochs + 1 ):
@@ -202,7 +202,9 @@ def load(cls, path, reload=False, src=None, **kwargs):
202
202
model .load_state_dict (state ['state_dict' ], False )
203
203
model .to (args .device )
204
204
transform = state ['transform' ]
205
- return cls (args , model , transform )
205
+ parser = cls (args , model , transform )
206
+ parser .checkpoint_state_dict = state ['checkpoint_state_dict' ] if args .checkpoint else None
207
+ return parser
206
208
207
209
def save (self , path ):
208
210
model = self .model
@@ -223,15 +225,16 @@ def save_checkpoint(self, path):
223
225
if hasattr (model , 'module' ):
224
226
model = self .model .module
225
227
args = model .args
226
- args . state_args = {k : getattr (self , k ) for k in ['epoch' , 'best_e' , 'patience' , 'best_metric' , 'elapsed' ]}
227
- args . optimizer_state_dict = self .optimizer .state_dict ()
228
- args . scheduler_state_dict = self .scheduler .state_dict ()
229
- args . rng_state = get_rng_state ()
228
+ checkpoint_state_dict = {k : getattr (self , k ) for k in ['epoch' , 'best_e' , 'patience' , 'best_metric' , 'elapsed' ]}
229
+ checkpoint_state_dict . update ({ ' optimizer_state_dict' : self .optimizer .state_dict (),
230
+ 'scheduler_state_dict' : self .scheduler .state_dict (),
231
+ 'rng_state' : get_rng_state ()} )
230
232
state_dict = {k : v .cpu () for k , v in model .state_dict ().items ()}
231
233
pretrained = state_dict .pop ('pretrained.weight' , None )
232
234
state = {'name' : self .NAME ,
233
235
'args' : args ,
234
236
'state_dict' : state_dict ,
235
237
'pretrained' : pretrained ,
238
+ 'checkpoint_state_dict' : checkpoint_state_dict ,
236
239
'transform' : self .transform }
237
240
torch .save (state , path , pickle_module = dill )
0 commit comments