@@ -34,15 +34,13 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
34
34
init_logger (logger , verbose = args .verbose )
35
35
36
36
self .transform .train ()
37
+ batch_size = batch_size // update_steps
37
38
if dist .is_initialized ():
38
- args . batch_size = args . batch_size // dist .get_world_size ()
39
+ batch_size = batch_size // dist .get_world_size ()
39
40
logger .info ("Loading the data" )
40
- train = Dataset (self .transform , args .train , ** args )
41
- dev = Dataset (self .transform , args .dev )
42
- test = Dataset (self .transform , args .test )
43
- train .build (args .batch_size // args .update_steps , args .buckets , True , dist .is_initialized ())
44
- dev .build (args .batch_size , args .buckets )
45
- test .build (args .batch_size , args .buckets )
41
+ train = Dataset (self .transform , args .train , ** args ).build (batch_size , buckets , True , dist .is_initialized ())
42
+ dev = Dataset (self .transform , args .dev ).build (batch_size , buckets )
43
+ test = Dataset (self .transform , args .test ).build (batch_size , buckets )
46
44
logger .info (f"\n { 'train:' :6} { train } \n { 'dev:' :6} { dev } \n { 'test:' :6} { test } \n " )
47
45
48
46
if args .encoder == 'lstm' :
@@ -108,7 +106,7 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
108
106
self .transform .train ()
109
107
logger .info ("Loading the data" )
110
108
dataset = Dataset (self .transform , data )
111
- dataset .build (args . batch_size , args . buckets )
109
+ dataset .build (batch_size , buckets )
112
110
logger .info (f"\n { dataset } " )
113
111
114
112
logger .info ("Evaluating the dataset" )
@@ -130,7 +128,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
130
128
131
129
logger .info ("Loading the data" )
132
130
dataset = Dataset (self .transform , data , lang = lang )
133
- dataset .build (args . batch_size , args . buckets )
131
+ dataset .build (batch_size , buckets )
134
132
logger .info (f"\n { dataset } " )
135
133
136
134
logger .info ("Making predictions on the dataset" )
@@ -163,7 +161,7 @@ def build(cls, path, **kwargs):
163
161
raise NotImplementedError
164
162
165
163
@classmethod
166
- def load (cls , path , reload = False , src = None , ** kwargs ):
164
+ def load (cls , path , reload = False , src = None , checkpoint = False , ** kwargs ):
167
165
r"""
168
166
Loads a parser with data fields and pretrained model parameters.
169
167
@@ -179,6 +177,8 @@ def load(cls, path, reload=False, src=None, **kwargs):
179
177
``'github'``: github release page.
180
178
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
181
179
Default: None.
180
+ checkpoint (bool):
181
+ If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
182
182
kwargs (dict):
183
183
A dict holding unconsumed arguments for updating training configs and initializing the model.
184
184
0 commit comments