
Commit 97a646d: Checkpoint support

1 parent: bc6a155

File tree: 12 files changed, 83 additions & 34 deletions
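In brief: every `train` subcommand gains a `--checkpoint` flag; `Parser` gains a `save_checkpoint` method and restore logic covering the optimizer, scheduler, RNG states, and progress counters (epoch, patience, best metric, elapsed time); the bucket sampler now seeds its shuffling by a 1-based epoch so a resumed run replays the same batch order; and `get_rng_state`/`set_rng_state` helpers land in `supar/utils/fn.py`.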

supar/cmds/biaffine_dep.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')
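
The same flag is added to every other `train` subcommand below. A hypothetical resume session (device id, model path, and data files are placeholders; only `--checkpoint` is introduced by this commit):

    $ python -m supar.cmds.biaffine_dep train -b -d 0 -p exp/model \
          --train train.conllx --dev dev.conllx --test test.conllx
    # interrupted; rerun the same command with --checkpoint to pick up where it left off:
    $ python -m supar.cmds.biaffine_dep train -d 0 -p exp/model \
          --train train.conllx --dev dev.conllx --test test.conllx --checkpoint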

supar/cmds/biaffine_sdp.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')
     subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')

supar/cmds/cmd.py

Lines changed: 4 additions & 4 deletions
@@ -21,15 +21,15 @@ def parse(parser):
     torch.set_num_threads(args.threads)
     torch.manual_seed(args.seed)
     init_device(args.device, args.local_rank)
-    init_logger(logger, f"{args.path}.{args.mode}.log")
+    init_logger(logger, f"{args.path}.{args.mode}.log", 'a' if args.get('checkpoint') else 'w')
     logger.info('\n' + str(args))

     if args.mode == 'train':
-        parser = Parser.build(**args)
+        parser = Parser.load(**args) if args.checkpoint else Parser.build(**args)
         parser.train(**args)
     elif args.mode == 'evaluate':
-        parser = Parser.load(args.path)
+        parser = Parser.load(**args)
         parser.evaluate(**args)
     elif args.mode == 'predict':
-        parser = Parser.load(args.path)
+        parser = Parser.load(**args)
         parser.predict(**args)
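
Two effects on a checkpointed run: the log file is opened in append mode ('a') so earlier epochs' logs survive, and the train branch calls `Parser.load` instead of `Parser.build`, which restores the model weights and stashes the checkpoint payload on the parser for `train()` to consume. `args.get('checkpoint')` guards the logger call because the evaluate and predict subcommands never define that flag.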

supar/cmds/crf2o_dep.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')

supar/cmds/crf_con.py

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')
     subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')

supar/cmds/crf_dep.py

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')

supar/cmds/vi_con.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')
     subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')

supar/cmds/vi_dep.py

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--punct', action='store_true', help='whether to include punctuation')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')

supar/cmds/vi_sdp.py

Lines changed: 1 addition & 0 deletions
@@ -14,6 +14,7 @@ def main():
     subparser = subparsers.add_parser('train', help='Train a parser.')
     subparser.add_argument('--feat', '-f', choices=['tag', 'char', 'lemma', 'bert'], nargs='+', help='features to use')
     subparser.add_argument('--build', '-b', action='store_true', help='whether to build the model first')
+    subparser.add_argument('--checkpoint', action='store_true', help='whether to load a checkpoint to restore training')
     subparser.add_argument('--encoder', choices=['lstm', 'bert'], default='lstm', help='encoder to use')
     subparser.add_argument('--max-len', type=int, help='max length of the sentences')
     subparser.add_argument('--buckets', default=32, type=int, help='max num of buckets to use')

supar/parsers/parser.py

Lines changed: 54 additions & 23 deletions
@@ -9,7 +9,7 @@
 import torch.distributed as dist
 from supar.utils import Config, Dataset
 from supar.utils.field import Field
-from supar.utils.fn import download
+from supar.utils.fn import download, get_rng_state, set_rng_state
 from supar.utils.logging import init_logger, logger
 from supar.utils.metric import Metric
 from supar.utils.parallel import DistributedDataParallel as DDP
@@ -34,15 +34,13 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
         init_logger(logger, verbose=args.verbose)

         self.transform.train()
+        batch_size = batch_size // update_steps
         if dist.is_initialized():
-            args.batch_size = args.batch_size // dist.get_world_size()
+            batch_size = batch_size // dist.get_world_size()
         logger.info("Loading the data")
-        train = Dataset(self.transform, args.train, **args)
-        dev = Dataset(self.transform, args.dev)
-        test = Dataset(self.transform, args.test)
-        train.build(args.batch_size//args.update_steps, args.buckets, True, dist.is_initialized())
-        dev.build(args.batch_size, args.buckets)
-        test.build(args.batch_size, args.buckets)
+        train = Dataset(self.transform, args.train, **args).build(batch_size, buckets, True, dist.is_initialized())
+        dev = Dataset(self.transform, args.dev).build(batch_size, buckets)
+        test = Dataset(self.transform, args.test).build(batch_size, buckets)
         logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

         if args.encoder == 'lstm':
@@ -60,10 +58,16 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
         if dist.is_initialized():
             self.model = DDP(self.model, device_ids=[args.local_rank], find_unused_parameters=True)

-        elapsed = timedelta()
-        best_e, best_metric = 1, Metric()
+        self.epoch, self.best_e, self.patience, self.best_metric, self.elapsed = 1, 1, patience, Metric(), timedelta()
+        if self.args.checkpoint:
+            self.optimizer.load_state_dict(self.checkpoint_state_dict.pop('optimizer_state_dict'))
+            self.scheduler.load_state_dict(self.checkpoint_state_dict.pop('scheduler_state_dict'))
+            set_rng_state(self.checkpoint_state_dict.pop('rng_state'))
+            for k, v in self.checkpoint_state_dict.items():
+                setattr(self, k, v)
+            train.loader.batch_sampler.epoch = self.epoch

-        for epoch in range(1, args.epochs + 1):
+        for epoch in range(self.epoch, args.epochs + 1):
             start = datetime.now()

             logger.info(f"Epoch {epoch} / {args.epochs}:")
@@ -74,22 +78,26 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
             logger.info(f"{'test:':5} loss: {loss:.4f} - {test_metric}")

             t = datetime.now() - start
-            if dev_metric > best_metric:
-                best_e, best_metric = epoch, dev_metric
+            self.epoch += 1
+            self.patience -= 1
+            self.elapsed += t
+
+            if dev_metric > self.best_metric:
+                self.best_e, self.patience, self.best_metric = epoch, patience, dev_metric
                 if is_master():
-                    self.save(args.path)
+                    self.save_checkpoint(args.path)
                 logger.info(f"{t}s elapsed (saved)\n")
             else:
                 logger.info(f"{t}s elapsed\n")
-            elapsed += t
-            if epoch - best_e >= args.patience:
+            if self.patience < 1:
                 break
         loss, metric = self.load(**args)._evaluate(test.loader)
+        self.save(args.path)

-        logger.info(f"Epoch {best_e} saved")
-        logger.info(f"{'dev:':5} {best_metric}")
+        logger.info(f"Epoch {self.best_e} saved")
+        logger.info(f"{'dev:':5} {self.best_metric}")
         logger.info(f"{'test:':5} {metric}")
-        logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")
+        logger.info(f"{self.elapsed}s elapsed, {self.elapsed / epoch}s/epoch")

     def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
         args = self.args.update(locals())
@@ -98,7 +106,7 @@ def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
         self.transform.train()
         logger.info("Loading the data")
         dataset = Dataset(self.transform, data)
-        dataset.build(args.batch_size, args.buckets)
+        dataset.build(batch_size, buckets)
         logger.info(f"\n{dataset}")

         logger.info("Evaluating the dataset")
@@ -120,7 +128,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=F

         logger.info("Loading the data")
         dataset = Dataset(self.transform, data, lang=lang)
-        dataset.build(args.batch_size, args.buckets)
+        dataset.build(batch_size, buckets)
         logger.info(f"\n{dataset}")

         logger.info("Making predictions on the dataset")
@@ -153,7 +161,7 @@ def build(cls, path, **kwargs):
         raise NotImplementedError

     @classmethod
-    def load(cls, path, reload=False, src=None, **kwargs):
+    def load(cls, path, reload=False, src=None, checkpoint=False, **kwargs):
         r"""
         Loads a parser with data fields and pretrained model parameters.

@@ -169,6 +177,8 @@ def load(cls, path, reload=False, src=None, **kwargs):
                 ``'github'``: github release page.
                 ``'hlt'``: hlt homepage, only accessible from 9:00 to 18:00 (UTC+8).
                 Default: None.
+            checkpoint (bool):
+                If ``True``, loads all checkpoint states to restore the training process. Default: ``False``.
             kwargs (dict):
                 A dict holding unconsumed arguments for updating training configs and initializing the model.

@@ -192,7 +202,9 @@ def load(cls, path, reload=False, src=None, **kwargs):
         model.load_state_dict(state['state_dict'], False)
         model.to(args.device)
         transform = state['transform']
-        return cls(args, model, transform)
+        parser = cls(args, model, transform)
+        parser.checkpoint_state_dict = state['checkpoint_state_dict'] if args.checkpoint else None
+        return parser

     def save(self, path):
         model = self.model
@@ -207,3 +219,22 @@ def save(self, path):
                  'pretrained': pretrained,
                  'transform': self.transform}
         torch.save(state, path, pickle_module=dill)
+
+    def save_checkpoint(self, path):
+        model = self.model
+        if hasattr(model, 'module'):
+            model = self.model.module
+        args = model.args
+        checkpoint_state_dict = {k: getattr(self, k) for k in ['epoch', 'best_e', 'patience', 'best_metric', 'elapsed']}
+        checkpoint_state_dict.update({'optimizer_state_dict': self.optimizer.state_dict(),
+                                      'scheduler_state_dict': self.scheduler.state_dict(),
+                                      'rng_state': get_rng_state()})
+        state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
+        pretrained = state_dict.pop('pretrained.weight', None)
+        state = {'name': self.NAME,
+                 'args': args,
+                 'state_dict': state_dict,
+                 'pretrained': pretrained,
+                 'checkpoint_state_dict': checkpoint_state_dict,
+                 'transform': self.transform}
+        torch.save(state, path, pickle_module=dill)
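
`save_checkpoint` writes a superset of the regular `save` payload, so the best checkpoint doubles as a loadable model; on resume, `train()` pops the optimizer, scheduler, and RNG states back out and fast-forwards the sampler to `self.epoch`. A minimal sketch of inspecting a checkpoint, assuming `dill` is installed and `'exp/model'` is a placeholder path written by `save_checkpoint`:

    import dill
    import torch

    # dill is required because supar saves with pickle_module=dill
    state = torch.load('exp/model', pickle_module=dill, map_location='cpu')

    # the extra payload that plain save() does not write
    ckpt = state['checkpoint_state_dict']
    print(ckpt['epoch'], ckpt['best_e'], ckpt['patience'], ckpt['elapsed'])
    print(sorted(ckpt))  # also: best_metric, optimizer_state_dict, scheduler_state_dict, rng_state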

supar/utils/data.py

Lines changed: 3 additions & 6 deletions
@@ -131,18 +131,15 @@ def __init__(self, buckets, batch_size, shuffle=False, distributed=False):
         self.rank = dist.get_rank() if distributed else 0
         self.replicas = dist.get_world_size() if distributed else 1
         self.samples = sum(self.chunks) // self.replicas
-        self.epoch = 0
+        self.epoch = 1

     def __iter__(self):
         g = torch.Generator()
         g.manual_seed(self.epoch)
-        range_fn = torch.arange
+        total, count = 0, 0
         # if `shuffle=True`, shuffle both the buckets and samples in each bucket
         # for distributed training, make sure each process generates the same random sequence at each epoch
-        if self.shuffle:
-            def range_fn(x):
-                return torch.randperm(x, generator=g)
-        total, count = 0, 0
+        range_fn = torch.arange if not self.shuffle else lambda x: torch.randperm(x, generator=g)
         # TODO: more elegant way to deal with uneven data, which we directly discard right now
         for i in range_fn(len(self.buckets)).tolist():
             split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1 for j in range(self.chunks[i])]
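
Starting `epoch` at 1 matches the training loop's 1-based count, and seeding the generator with the epoch is what makes the sampler restartable: a resumed run that sets `sampler.epoch` back to the saved value regenerates exactly the same bucket and sample order every process saw before the interruption. A small sketch of the idea:

    import torch

    def epoch_perm(n, epoch):
        # same scheme as Sampler.__iter__: a fresh generator seeded by the epoch
        g = torch.Generator()
        g.manual_seed(epoch)
        return torch.randperm(n, generator=g)

    # identical epochs yield identical orders, across processes and across restarts
    assert epoch_perm(8, epoch=3).tolist() == epoch_perm(8, epoch=3).tolist()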

supar/utils/fn.py

Lines changed: 14 additions & 1 deletion
@@ -99,7 +99,20 @@ def download(url, reload=False):
             members = f.infolist()
             path = os.path.join(os.path.dirname(path), members[0].filename)
             if len(members) != 1:
-                raise RuntimeError('Only one file(not dir) is allowed in the zipfile.')
+                raise RuntimeError('Only one file (not dir) is allowed in the zipfile.')
             if reload or not os.path.exists(path):
                 f.extractall(os.path.dirname(path))
     return path
+
+
+def get_rng_state():
+    state = {'rng_state': torch.get_rng_state()}
+    if torch.cuda.is_available():
+        state['cuda_rng_state'] = torch.cuda.get_rng_state()
+    return state
+
+
+def set_rng_state(state):
+    torch.set_rng_state(state['rng_state'])
+    if torch.cuda.is_available():
+        torch.cuda.set_rng_state(state['cuda_rng_state'])
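
A quick illustration of why capturing and restoring the RNG state makes a resumed run replay the exact same random draws (CPU shown; the CUDA branch behaves the same way):

    import torch
    from supar.utils.fn import get_rng_state, set_rng_state

    state = get_rng_state()   # snapshot the CPU (and, if available, CUDA) RNG state
    a = torch.rand(3)
    set_rng_state(state)      # rewind, as done when restoring a checkpoint
    b = torch.rand(3)
    assert torch.equal(a, b)  # post-restore draws match the originals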
