Skip to content

Commit 62f660c

Browse files
committed
Fix a bug
1 parent d6029c5 commit 62f660c

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

supar/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,6 @@
6565
}
6666
MODEL = {n: f"{SRC['github']}/v1.1.0/{m}.zip" for n, m in NAME.items()}
6767
CONFIG = {n: f"{SRC['github']}/v1.1.0/{m}.ini" for n, m in NAME.items()}
68+
69+
MODEL['biaffine-sdp-en'] = f"{SRC['hlt']}/v1.1.1/{NAME['biaffine-sdp-en']}.zip"
70+
MODEL['biaffine-sdp-zh'] = f"{SRC['hlt']}/v1.1.1/{NAME['biaffine-sdp-zh']}.zip"

supar/cmds/cmd.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,11 @@ def parse(parser):
2525
logger.info('\n' + str(args))
2626

2727
if args.mode == 'train':
28-
parser = Parser.load(args.path) if args.checkpoint else Parser.build(**args)
28+
parser = Parser.load(**args) if args.checkpoint else Parser.build(**args)
2929
parser.train(**args)
3030
elif args.mode == 'evaluate':
31-
parser = Parser.load(args.path)
31+
parser = Parser.load(**args)
3232
parser.evaluate(**args)
3333
elif args.mode == 'predict':
34-
parser = Parser.load(args.path)
34+
parser = Parser.load(**args)
3535
parser.predict(**args)

supar/parsers/parser.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,11 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
6262

6363
self.epoch, self.best_e, self.patience, self.best_metric, self.elapsed = 1, 1, patience, Metric(), timedelta()
6464
if self.args.checkpoint:
65-
self.optimizer.load_state_dict(self.args.pop('optimizer_state_dict'))
66-
self.scheduler.load_state_dict(self.args.pop('scheduler_state_dict'))
67-
for k, v in args.pop('state_args').items():
65+
self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict'))
66+
self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict'))
67+
set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
68+
for k, v in self.checkpoint_state_dict.items():
6869
setattr(self, k, v)
69-
set_rng_state(args.pop('rng_state'))
7070
train.loader.batch_sampler.epoch = self.epoch
7171

7272
for epoch in range(self.epoch, args.epochs + 1):
@@ -202,7 +202,9 @@ def load(cls, path, reload=False, src=None, **kwargs):
202202
model.load_state_dict(state['state_dict'], False)
203203
model.to(args.device)
204204
transform = state['transform']
205-
return cls(args, model, transform)
205+
parser = cls(args, model, transform)
206+
parser.checkpoint_state_dict = state['checkpoint_state_dict'] if args.checkpoint else None
207+
return parser
206208

207209
def save(self, path):
208210
model = self.model
@@ -223,15 +225,16 @@ def save_checkpoint(self, path):
223225
if hasattr(model, 'module'):
224226
model = self.model.module
225227
args = model.args
226-
args.state_args = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']}
227-
args.optimizer_state_dict = self.optimizer.state_dict()
228-
args.scheduler_state_dict = self.scheduler.state_dict()
229-
args.rng_state = get_rng_state()
228+
checkpoint_state_dict = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']}
229+
checkpoint_state_dict.update({'optimizer_state_dict': self.optimizer.state_dict(),
230+
'scheduler_state_dict': self.scheduler.state_dict(),
231+
'rng_state': get_rng_state()})
230232
state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
231233
pretrained = state_dict.pop('pretrained.weight', None)
232234
state = {'name': self.NAME,
233235
'args': args,
234236
'state_dict': state_dict,
235237
'pretrained': pretrained,
238+
'checkpoint_state_dict': checkpoint_state_dict,
236239
'transform': self.transform}
237240
torch.save(state, path, pickle_module=dill)

0 commit comments

Comments
 (0)