
Commit 9cb8334

Move init of optimizer/scheduler to build
1 parent 2548c24 commit 9cb8334

File tree

supar/parsers/constituency.py
supar/parsers/dependency.py
supar/parsers/parser.py

3 files changed: 70 additions, 62 deletions

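In short, the commit moves optimizer and scheduler construction out of Parser.train and into each parser's build classmethod, so a freshly built parser already carries both. A before/after sketch of the call sites (hypothetical file paths; the remaining build arguments are elided):

    # before: optimization hyperparameters were passed to train()
    parser = BiaffineDependencyParser.build('model.pt')
    parser.train('train.conllx', 'dev.conllx', 'test.conllx',
                 lr=2e-3, mu=.9, nu=.9, epsilon=1e-12,
                 decay=.75, decay_steps=5000)

    # after: they are passed to build(), packed into two dicts
    parser = BiaffineDependencyParser.build(
        'model.pt',
        optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
        scheduler_args={'gamma': .75**(1/5000)})
    parser.train('train.conllx', 'dev.conllx', 'test.conllx')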

supar/parsers/constituency.py

Lines changed: 18 additions & 2 deletions
@@ -12,6 +12,8 @@
 from supar.utils.logging import get_logger, progress_bar
 from supar.utils.metric import SpanMetric
 from supar.utils.transform import Tree
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR


 logger = get_logger(__name__)
@@ -199,13 +201,22 @@ def _predict(self, loader):
         return preds

     @classmethod
-    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+    def build(cls, path,
+              optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
+              scheduler_args={'gamma': .75**(1/5000)},
+              min_freq=2,
+              fix_len=20,
+              **kwargs):
         r"""
         Build a brand-new Parser, including initialization of all data fields and model parameters.

         Args:
             path (str):
                 The path of the model to be saved.
+            optimizer_args (dict):
+                Arguments for creating an optimizer.
+            scheduler_args (dict):
+                Arguments for creating a scheduler.
             min_freq (str):
                 The minimum frequency needed to include a token in the vocabulary. Default: 2.
             fix_len (int):
@@ -263,6 +274,11 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
             'eos_index': WORD.eos_index,
             'feat_pad_index': FEAT.pad_index
         })
+
+        logger.info("Building the model")
         model = cls.MODEL(**args)
         model.load_pretrained(WORD.embed).to(args.device)
-        return cls(args, model, transform)
+        optimizer = Adam(model.parameters(), **optimizer_args)
+        scheduler = ExponentialLR(optimizer, **scheduler_args)
+
+        return cls(args, model, transform, optimizer, scheduler)
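The scheduler_args default above folds the old decay/decay_steps pair into a single factor: ExponentialLR multiplies the learning rate by gamma after every step, so gamma = decay**(1/decay_steps) shrinks the rate by 0.75 over every 5000 steps, exactly as the removed train() defaults did. A quick sanity check in plain Python:

    decay, decay_steps = .75, 5000
    gamma = decay ** (1 / decay_steps)      # the new scheduler_args default

    lr0 = 2e-3                              # the new optimizer_args default lr
    lr_after = lr0 * gamma ** decay_steps   # learning rate after 5000 steps
    assert abs(lr_after - lr0 * decay) < 1e-12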

supar/parsers/dependency.py

Lines changed: 44 additions & 36 deletions
@@ -14,6 +14,8 @@
 from supar.utils.logging import get_logger, progress_bar
 from supar.utils.metric import AttachmentMetric
 from supar.utils.transform import CoNLL
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR


 logger = get_logger(__name__)
@@ -168,9 +170,7 @@ def _evaluate(self, loader):
             mask[:, 0] = 0
             s_arc, s_rel = self.model(words, feats)
             loss = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.partial)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             if self.args.partial:
                 mask &= arcs.ge(0)
             # ignore all punctuation if not specified
@@ -194,9 +194,7 @@ def _predict(self, loader):
             mask[:, 0] = 0
             lens = mask.sum(1).tolist()
             s_arc, s_rel = self.model(words, feats)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             arcs.extend(arc_preds[mask].split(lens))
             rels.extend(rel_preds[mask].split(lens))
             if self.args.prob:
@@ -211,13 +209,21 @@ def _predict(self, loader):
         return preds

     @classmethod
-    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+    def build(cls, path,
+              optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
+              scheduler_args={'gamma': .75**(1/5000)},
+              min_freq=2,
+              fix_len=20, **kwargs):
         r"""
         Build a brand-new Parser, including initialization of all data fields and model parameters.

         Args:
             path (str):
                 The path of the model to be saved.
+            optimizer_args (dict):
+                Arguments for creating an optimizer.
+            scheduler_args (dict):
+                Arguments for creating a scheduler.
             min_freq (str):
                 The minimum frequency needed to include a token in the vocabulary. Default: 2.
             fix_len (int):
@@ -273,9 +279,15 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
             'bos_index': WORD.bos_index,
             'feat_pad_index': FEAT.pad_index,
         })
+
+        logger.info("Building the model")
         model = cls.MODEL(**args)
         model.load_pretrained(WORD.embed).to(args.device)
-        return cls(args, model, transform)
+
+        optimizer = Adam(model.parameters(), **optimizer_args)
+        scheduler = ExponentialLR(optimizer, **scheduler_args)
+
+        return cls(args, model, transform, optimizer, scheduler)


 class CRFNPDependencyParser(BiaffineDependencyParser):
@@ -584,9 +596,7 @@ def _train(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
+            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
             self.optimizer.step()
@@ -612,12 +622,8 @@ def _evaluate(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask, self.args.mbr, self.args.partial)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             if self.args.partial:
                 mask &= arcs.ge(0)
             # ignore all punctuation if not specified
@@ -643,9 +649,7 @@ def _predict(self, loader):
             s_arc, s_rel = self.model(words, feats)
             if self.args.mbr:
                 s_arc = self.model.crf(s_arc, mask, mbr=True)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_rel, mask, self.args.tree, self.args.proj)
             arcs.extend(arc_preds[mask].split(lens))
             rels.extend(rel_preds[mask].split(lens))
             if self.args.prob:
@@ -780,9 +784,7 @@ def _train(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_sib, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
+            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
             self.optimizer.step()
@@ -808,13 +810,8 @@ def _evaluate(self, loader):
             # ignore the first token of each sentence
             mask[:, 0] = 0
             s_arc, s_sib, s_rel = self.model(words, feats)
-            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask,
-                                          self.args.mbr,
-                                          self.args.partial)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.mbr,
-                                                     self.args.proj)
+            loss, s_arc = self.model.loss(s_arc, s_sib, s_rel, arcs, sibs, rels, mask, self.args.mbr, self.args.partial)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
             if self.args.partial:
                 mask &= arcs.ge(0)
             # ignore all punctuation if not specified
@@ -840,10 +837,7 @@ def _predict(self, loader):
             s_arc, s_sib, s_rel = self.model(words, feats)
             if self.args.mbr:
                 s_arc = self.model.crf((s_arc, s_sib), mask, mbr=True)
-            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask,
-                                                     self.args.tree,
-                                                     self.args.mbr,
-                                                     self.args.proj)
+            arc_preds, rel_preds = self.model.decode(s_arc, s_sib, s_rel, mask, self.args.tree, self.args.mbr, self.args.proj)
             arcs.extend(arc_preds[mask].split(lens))
             rels.extend(rel_preds[mask].split(lens))
             if self.args.prob:
@@ -858,13 +852,21 @@ def _predict(self, loader):
         return preds

     @classmethod
-    def build(cls, path, min_freq=2, fix_len=20, **kwargs):
+    def build(cls, path,
+              optimizer_args={'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12},
+              scheduler_args={'gamma': .75**(1/5000)},
+              min_freq=2,
+              fix_len=20, **kwargs):
         r"""
         Build a brand-new Parser, including initialization of all data fields and model parameters.

         Args:
             path (str):
                 The path of the model to be saved.
+            optimizer_args (dict):
+                Arguments for creating an optimizer.
+            scheduler_args (dict):
+                Arguments for creating a scheduler.
             min_freq (str):
                 The minimum frequency needed to include a token in the vocabulary. Default: 2.
             fix_len (int):
@@ -921,6 +923,12 @@ def build(cls, path, min_freq=2, fix_len=20, **kwargs):
             'bos_index': WORD.bos_index,
             'feat_pad_index': FEAT.pad_index
         })
+
+        logger.info("Building the model")
         model = cls.MODEL(**args)
         model = model.load_pretrained(WORD.embed).to(args.device)
-        return cls(args, model, transform)
+
+        optimizer = Adam(model.parameters(), **optimizer_args)
+        scheduler = ExponentialLR(optimizer, **scheduler_args)
+
+        return cls(args, model, transform, optimizer, scheduler)
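A note on the new signatures (an observation about the diff, not something it changes): the dict literals used as defaults are created once and shared across calls, which stays harmless only as long as build never mutates them. The conventional defensive spelling would use None defaults, e.g. this hypothetical variant:

    def build(cls, path, optimizer_args=None, scheduler_args=None,
              min_freq=2, fix_len=20, **kwargs):
        if optimizer_args is None:
            optimizer_args = {'lr': 2e-3, 'betas': (.9, .9), 'eps': 1e-12}
        if scheduler_args is None:
            scheduler_args = {'gamma': .75**(1/5000)}
        ...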

supar/parsers/parser.py

Lines changed: 8 additions & 24 deletions
@@ -12,34 +12,21 @@
 from supar.utils.metric import Metric
 from supar.utils.parallel import DistributedDataParallel as DDP
 from supar.utils.parallel import is_master
-from torch.optim import Adam
-from torch.optim.lr_scheduler import ExponentialLR


 class Parser(object):

     NAME = None
     MODEL = None

-    def __init__(self, args, model, transform):
+    def __init__(self, args, model, transform, optimizer=None, scheduler=None):
         self.args = args
         self.model = model
         self.transform = transform
+        self.optimizer = optimizer
+        self.scheduler = scheduler

-    def train(self, train, dev, test,
-              buckets=32,
-              batch_size=5000,
-              lr=2e-3,
-              mu=.9,
-              nu=.9,
-              epsilon=1e-12,
-              clip=5.0,
-              decay=.75,
-              decay_steps=5000,
-              epochs=5000,
-              patience=100,
-              verbose=True,
-              **kwargs):
+    def train(self, train, dev, test, buckets=32, batch_size=5000, clip=5.0, epochs=5000, patience=100, **kwargs):
         args = self.args.update(locals())
         init_logger(logger, verbose=args.verbose)

@@ -55,11 +42,8 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, clip=5.0, epochs=5000, patience=100, **kwargs):
             test.build(args.batch_size, args.buckets)
         logger.info(f"\n{'train:':6} {train}\n{'dev:':6} {dev}\n{'test:':6} {test}\n")

-        logger.info(f"{self.model}\n")
         if dist.is_initialized():
             self.model = DDP(self.model, device_ids=[args.local_rank], find_unused_parameters=True)
-        self.optimizer = Adam(self.model.parameters(), args.lr, (args.mu, args.nu), args.epsilon)
-        self.scheduler = ExponentialLR(self.optimizer, args.decay**(1/args.decay_steps))

         elapsed = timedelta()
         best_e, best_metric = 1, Metric()
@@ -70,9 +54,9 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, clip=5.0, epochs=5000, patience=100, **kwargs):
             logger.info(f"Epoch {epoch} / {args.epochs}:")
             self._train(train.loader)
             loss, dev_metric = self._evaluate(dev.loader)
-            logger.info(f"{'dev:':6} - loss: {loss:.4f} - {dev_metric}")
+            logger.info(f"{'dev:':6} loss: {loss:.4f} - {dev_metric}")
             loss, test_metric = self._evaluate(test.loader)
-            logger.info(f"{'test:':6} - loss: {loss:.4f} - {test_metric}")
+            logger.info(f"{'test:':6} loss: {loss:.4f} - {test_metric}")

             t = datetime.now() - start
             # save the model if it is the best so far
@@ -89,8 +73,8 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, clip=5.0, epochs=5000, patience=100, **kwargs):
         loss, metric = self.load(**args)._evaluate(test.loader)

         logger.info(f"Epoch {best_e} saved")
-        logger.info(f"{'dev:':6} - {best_metric}")
-        logger.info(f"{'test:':6} - {metric}")
+        logger.info(f"{'dev:':6} {best_metric}")
+        logger.info(f"{'test:':6} {metric}")
         logger.info(f"{elapsed}s elapsed, {elapsed / epoch}s/epoch")

     def evaluate(self, data, buckets=8, batch_size=5000, **kwargs):
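With __init__ now storing whatever optimizer and scheduler it receives (both default to None, so code paths that construct a Parser without them keep working), the training loop consumes self.optimizer and self.scheduler exactly as before; only the construction site moved to build. A condensed view of one training step, assembled from the dependency-parser hunks above (the scheduler.step() call is assumed from the surrounding loop; it is not among this diff's context lines):

    s_arc, s_rel = self.model(words, feats)
    loss, s_arc = self.model.loss(s_arc, s_rel, arcs, rels, mask,
                                  self.args.mbr, self.args.partial)
    loss.backward()
    nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
    self.optimizer.step()   # created by Adam(...) in build(), stored by __init__
    self.scheduler.step()   # assumed: advances the ExponentialLR decay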
