9
9
import torch .distributed as dist
10
10
from supar .utils import Config , Dataset
11
11
from supar .utils .field import Field
12
- from supar .utils .fn import download
12
+ from supar .utils .fn import download , get_rng_state , set_rng_state
13
13
from supar .utils .logging import init_logger , logger
14
14
from supar .utils .metric import Metric
15
15
from supar .utils .parallel import DistributedDataParallel as DDP
@@ -34,15 +34,13 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
34
34
init_logger (logger , verbose = args .verbose )
35
35
36
36
self .transform .train ()
37
+ batch_size = batch_size // update_steps
37
38
if dist .is_initialized ():
38
- args . batch_size = args . batch_size // dist .get_world_size ()
39
+ batch_size = batch_size // dist .get_world_size ()
39
40
logger .info ("Loading the data" )
40
- train = Dataset (self .transform , args .train , ** args )
41
- dev = Dataset (self .transform , args .dev )
42
- test = Dataset (self .transform , args .test )
43
- train .build (args .batch_size // args .update_steps , args .buckets , True , dist .is_initialized ())
44
- dev .build (args .batch_size , args .buckets )
45
- test .build (args .batch_size , args .buckets )
41
+ train = Dataset (self .transform , args .train , ** args ).build (batch_size , buckets , True , dist .is_initialized ())
42
+ dev = Dataset (self .transform , args .dev ).build (batch_size , buckets )
43
+ test = Dataset (self .transform , args .test ).build (batch_size , buckets )
46
44
logger .info (f"\n { 'train:' :6} { train } \n { 'dev:' :6} { dev } \n { 'test:' :6} { test } \n " )
47
45
48
46
if args .encoder == 'lstm' :
@@ -60,10 +58,16 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
60
58
if dist .is_initialized ():
61
59
self .model = DDP (self .model , device_ids = [args .local_rank ], find_unused_parameters = True )
62
60
63
- elapsed = timedelta ()
64
- best_e , best_metric = 1 , Metric ()
61
+ self .epoch , self .best_e , self .patience , self .best_metric , self .elapsed = 1 , 1 , patience , Metric (), timedelta ()
62
+ if self .args .checkpoint :
63
+ self .optimizer .load_state_dict (self .checkpoint_state_dict .pop ('optimizer_state_dict' ))
64
+ self .scheduler .load_state_dict (self .checkpoint_state_dict .pop ('scheduler_state_dict' ))
65
+ set_rng_state (self .checkpoint_state_dict .pop ('rng_state' ))
66
+ for k , v in self .checkpoint_state_dict .items ():
67
+ setattr (self , k , v )
68
+ train .loader .batch_sampler .epoch = self .epoch
65
69
66
- for epoch in range (1 , args .epochs + 1 ):
70
+ for epoch in range (self . epoch , args .epochs + 1 ):
67
71
start = datetime .now ()
68
72
69
73
logger .info (f"Epoch { epoch } / { args .epochs } :" )
@@ -74,22 +78,26 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
74
78
logger .info (f"{ 'test:' :5} loss: { loss :.4f} - { test_metric } " )
75
79
76
80
t = datetime .now () - start
77
- if dev_metric > best_metric :
78
- best_e , best_metric = epoch , dev_metric
81
+ self .epoch += 1
82
+ self .patience -= 1
83
+ self .elapsed += t
84
+
85
+ if dev_metric > self .best_metric :
86
+ self .best_e , self .patience , self .best_metric = epoch , patience , dev_metric
79
87
if is_master ():
80
- self .save (args .path )
88
+ self .save_checkpoint (args .path )
81
89
logger .info (f"{ t } s elapsed (saved)\n " )
82
90
else :
83
91
logger .info (f"{ t } s elapsed\n " )
84
- elapsed += t
85
- if epoch - best_e >= args .patience :
92
+ if self .patience < 1 :
86
93
break
87
94
loss , metric = self .load (** args )._evaluate (test .loader )
95
+ self .save (args .path )
88
96
89
- logger .info (f"Epoch { best_e } saved" )
90
- logger .info (f"{ 'dev:' :5} { best_metric } " )
97
+ logger .info (f"Epoch { self . best_e } saved" )
98
+ logger .info (f"{ 'dev:' :5} { self . best_metric } " )
91
99
logger .info (f"{ 'test:' :5} { metric } " )
92
- logger .info (f"{ elapsed } s elapsed, { elapsed / epoch } s/epoch" )
100
+ logger .info (f"{ self . elapsed } s elapsed, { self . elapsed / epoch } s/epoch" )
93
101
94
102
def evaluate (self , data , buckets = 8 , batch_size = 5000 , ** kwargs ):
95
103
args = self .args .update (locals ())
@@ -98,7 +106,7 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
98
106
self .transform .train ()
99
107
logger .info ("Loading the data" )
100
108
dataset = Dataset (self .transform , data )
101
- dataset .build (args . batch_size , args . buckets )
109
+ dataset .build (batch_size , buckets )
102
110
logger .info (f"\n { dataset } " )
103
111
104
112
logger .info ("Evaluating the dataset" )
@@ -120,7 +128,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
120
128
121
129
logger .info ("Loading the data" )
122
130
dataset = Dataset (self .transform , data , lang = lang )
123
- dataset .build (args . batch_size , args . buckets )
131
+ dataset .build (batch_size , buckets )
124
132
logger .info (f"\n { dataset } " )
125
133
126
134
logger .info ("Making predictions on the dataset" )
@@ -153,7 +161,7 @@ def build(cls, path, **kwargs):
153
161
raise NotImplementedError
154
162
155
163
@classmethod
156
- def load (cls , path , reload = False , src = None , ** kwargs ):
164
+ def load (cls , path , reload = False , src = None , checkpoint = False , ** kwargs ):
157
165
r"""
158
166
Loads a parser with data fields and pretrained model parameters.
159
167
@@ -169,6 +177,8 @@ def load(cls, path, reload=False, src=None, **kwargs):
169
177
``'github'``: github release page.
170
178
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
171
179
Default: None.
180
+ checkpoint (bool):
181
+ If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
172
182
kwargs (dict):
173
183
A dict holding unconsumed arguments for updating training configs and initializing the model.
174
184
@@ -192,7 +202,9 @@ def load(cls, path, reload=False, src=None, **kwargs):
192
202
model .load_state_dict (state ['state_dict' ], False )
193
203
model .to (args .device )
194
204
transform = state ['transform' ]
195
- return cls (args , model , transform )
205
+ parser = cls (args , model , transform )
206
+ parser .checkpoint_state_dict = state ['checkpoint_state_dict' ] if args .checkpoint else None
207
+ return parser
196
208
197
209
def save (self , path ):
198
210
model = self .model
@@ -207,3 +219,22 @@ def save(self, path):
207
219
'pretrained' : pretrained ,
208
220
'transform' : self .transform }
209
221
torch .save (state , path , pickle_module = dill )
222
+
223
def save_checkpoint(self, path):
    r"""
    Saves the model parameters together with the states needed to resume training.

    Besides the entries describing the model itself (``name``, ``args``,
    ``state_dict``, ``pretrained`` and ``transform``), the saved dict carries a
    ``checkpoint_state_dict`` entry holding the training progress attributes
    (``epoch``, ``best_e``, ``patience``, ``best_metric``, ``elapsed``), the
    optimizer and scheduler state dicts, and the current RNG state.

    Args:
        path (str):
            Destination handed to :func:`torch.save`.
    """

    # Unwrap the model if it sits behind a wrapper exposing ``module``
    # (presumably DistributedDataParallel - confirm against the caller).
    network = self.model.module if hasattr(self.model, 'module') else self.model
    config = network.args

    # Training progress is collected first, then optimizer/scheduler/RNG states.
    resume_state = {}
    for key in ('epoch', 'best_e', 'patience', 'best_metric', 'elapsed'):
        resume_state[key] = getattr(self, key)
    resume_state['optimizer_state_dict'] = self.optimizer.state_dict()
    resume_state['scheduler_state_dict'] = self.scheduler.state_dict()
    resume_state['rng_state'] = get_rng_state()

    # Every parameter tensor is moved to CPU before being serialized.
    weights = {}
    for key, tensor in network.state_dict().items():
        weights[key] = tensor.cpu()
    # The ``pretrained.weight`` entry (if any) is stored apart from the other weights.
    embeddings = weights.pop('pretrained.weight', None)

    payload = dict(name=self.NAME,
                   args=config,
                   state_dict=weights,
                   pretrained=embeddings,
                   checkpoint_state_dict=resume_state,
                   transform=self.transform)
    torch.save(payload, path, pickle_module=dill)
0 commit comments