Skip to content

Commit dd0696e

Browse files
committed
Bug fix
1 parent 4ed8b6e commit dd0696e

File tree

1 file changed

+10
-10
lines changed

1 file changed

+10
-10
lines changed

supar/parsers/parser.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -34,15 +34,13 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
3434
init_logger(logger, verbose=args.verbose)
3535

3636
self.transform.train()
37+
batch_size = batch_size // update_steps
3738
if dist.is_initialized():
38-
args.batch_size = args.batch_size // dist.get_world_size()
39+
batch_size = batch_size // dist.get_world_size()
3940
logger.info("Loading the data")
40-
train = Dataset(self.transform, args.train, **args)
41-
dev = Dataset(self.transform, args.dev)
42-
test = Dataset(self.transform, args.test)
43-
train.build(args.batch_size//args.update_steps, args.buckets, True, dist.is_initialized())
44-
dev.build(args.batch_size, args.buckets)
45-
test.build(args.batch_size, args.buckets)
41+
train = Dataset(self.transform, args.train, **args).build(batch_size, buckets, True, dist.is_initialized())
42+
dev = Dataset(self.transform, args.dev).build(batch_size, buckets)
43+
test = Dataset(self.transform, args.test).build(batch_size, buckets)
4644
logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")
4745

4846
if args.encoder == 'lstm':
@@ -108,7 +106,7 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
108106
self.transform.train()
109107
logger.info("Loading the data")
110108
dataset = Dataset(self.transform, data)
111-
dataset.build(args.batch_size, args.buckets)
109+
dataset.build(batch_size, buckets)
112110
logger.info(f"\n{dataset}")
113111

114112
logger.info("Evaluating the dataset")
@@ -130,7 +128,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F
130128

131129
logger.info("Loading the data")
132130
dataset = Dataset(self.transform, data, lang=lang)
133-
dataset.build(args.batch_size, args.buckets)
131+
dataset.build(batch_size, buckets)
134132
logger.info(f"\n{dataset}")
135133

136134
logger.info("Making predictions on the dataset")
@@ -163,7 +161,7 @@ def build(cls, path, **kwargs):
163161
raise NotImplementedError
164162

165163
@classmethod
166-
def load(cls, path, reload=False, src=None, **kwargs):
164+
def load(cls, path, reload=False, src=None, checkpoint=False, **kwargs):
167165
r"""
168166
Loads a parser with data fields and pretrained model parameters.
169167
@@ -179,6 +177,8 @@ def load(cls, path, reload=False, src=None, **kwargs):
179177
``'github'``: github release page.
180178
``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
181179
Default: None.
180+
checkpoint (bool):
181+
If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
182182
kwargs (dict):
183183
A dict holding unconsumed arguments for updating training configs and initializing the model.
184184

0 commit comments

Comments (0)