
Commit 0adb1d3

Author: zysite

Move the training process to Train

1 parent 091e2ae, commit 0adb1d3

File tree

2 files changed: +109, -117 lines


parser/cmds/train.py

Lines changed: 95 additions & 60 deletions
@@ -1,12 +1,15 @@
 # -*- coding: utf-8 -*-
 
+import os
+from datetime import datetime, timedelta
 from parser import BiaffineParser, Model
-from parser.utils import Corpus, Embedding, TextDataset, Vocab, collate_fn
+from parser.metric import Metric
+from parser.utils import Corpus, Embedding, Vocab
+from parser.utils.data import TextDataset, batchify
 
 import torch
-from torch.utils.data import DataLoader
-
-from config import Config
+from torch.optim import Adam
+from torch.optim.lr_scheduler import ExponentialLR
 
 
 class Train(object):
@@ -15,78 +18,110 @@ def add_subparser(self, name, parser):
         subparser = parser.add_parser(
             name, help='Train a model.'
         )
-        subparser.add_argument('--ftrain', default='data/train.conllx',
+        subparser.add_argument('--buckets', default=64, type=int,
+                               help='max num of buckets to use')
+        subparser.add_argument('--punct', action='store_true',
+                               help='whether to include punctuation')
+        subparser.add_argument('--ftrain', default='data/ptb/train.conllx',
                                help='path to train file')
-        subparser.add_argument('--fdev', default='data/dev.conllx',
+        subparser.add_argument('--fdev', default='data/ptb/dev.conllx',
                                help='path to dev file')
-        subparser.add_argument('--ftest', default='data/test.conllx',
+        subparser.add_argument('--ftest', default='data/ptb/test.conllx',
                                help='path to test file')
         subparser.add_argument('--fembed', default='data/glove.6B.100d.txt',
-                               help='path to pretrained embedding file')
-        subparser.set_defaults(func=self)
+                               help='path to pretrained embeddings')
+        subparser.add_argument('--unk', default='unk',
+                               help='unk token in pretrained embeddings')
 
         return subparser
 
-    def __call__(self, args):
+    def __call__(self, config):
         print("Preprocess the data")
-        train = Corpus.load(args.ftrain)
-        dev = Corpus.load(args.fdev)
-        test = Corpus.load(args.ftest)
-        embed = Embedding.load(args.fembed)
-        vocab = Vocab.from_corpus(corpus=train, min_freq=2)
-        vocab.read_embeddings(embed=embed, unk='unk')
-        torch.save(vocab, args.vocab)
+        train = Corpus.load(config.ftrain)
+        dev = Corpus.load(config.fdev)
+        test = Corpus.load(config.ftest)
+        if os.path.exists(config.vocab):
+            vocab = torch.load(config.vocab)
+        else:
+            vocab = Vocab.from_corpus(corpus=train, min_freq=2)
+            vocab.read_embeddings(Embedding.load(config.fembed, config.unk))
+            torch.save(vocab, config.vocab)
+        config.update({
+            'n_words': vocab.n_train_words,
+            'n_tags': vocab.n_tags,
+            'n_rels': vocab.n_rels,
+            'pad_index': vocab.pad_index,
+            'unk_index': vocab.unk_index
+        })
         print(vocab)
 
         print("Load the dataset")
         trainset = TextDataset(vocab.numericalize(train))
         devset = TextDataset(vocab.numericalize(dev))
         testset = TextDataset(vocab.numericalize(test))
         # set the data loaders
-        train_loader = DataLoader(dataset=trainset,
-                                  batch_size=Config.batch_size,
-                                  shuffle=True,
-                                  collate_fn=collate_fn)
-        dev_loader = DataLoader(dataset=devset,
-                                batch_size=Config.batch_size,
-                                collate_fn=collate_fn)
-        test_loader = DataLoader(dataset=testset,
-                                 batch_size=Config.batch_size,
-                                 collate_fn=collate_fn)
-        print(f" size of trainset: {len(trainset)}")
-        print(f" size of devset: {len(devset)}")
-        print(f" size of testset: {len(testset)}")
+        train_loader = batchify(dataset=trainset,
+                                batch_size=config.batch_size,
+                                n_buckets=config.buckets,
+                                shuffle=True)
+        dev_loader = batchify(dataset=devset,
+                              batch_size=config.batch_size,
+                              n_buckets=config.buckets)
+        test_loader = batchify(dataset=testset,
+                               batch_size=config.batch_size,
+                               n_buckets=config.buckets)
+        print(f"{'train:':6} {len(trainset):5} sentences in total, "
+              f"{len(train_loader):3} batches provided")
+        print(f"{'dev:':6} {len(devset):5} sentences in total, "
+              f"{len(dev_loader):3} batches provided")
+        print(f"{'test:':6} {len(testset):5} sentences in total, "
+              f"{len(test_loader):3} batches provided")
 
         print("Create the model")
-        params = {
-            'n_words': vocab.n_train_words,
-            'n_embed': Config.n_embed,
-            'n_tags': vocab.n_tags,
-            'n_tag_embed': Config.n_tag_embed,
-            'embed_dropout': Config.embed_dropout,
-            'n_lstm_hidden': Config.n_lstm_hidden,
-            'n_lstm_layers': Config.n_lstm_layers,
-            'lstm_dropout': Config.lstm_dropout,
-            'n_mlp_arc': Config.n_mlp_arc,
-            'n_mlp_rel': Config.n_mlp_rel,
-            'mlp_dropout': Config.mlp_dropout,
-            'n_rels': vocab.n_rels,
-            'pad_index': vocab.pad_index,
-            'unk_index': vocab.unk_index
-        }
-        for k, v in params.items():
-            print(f" {k}: {v}")
-        network = BiaffineParser(params, vocab.embeddings)
+        parser = BiaffineParser(config, vocab.embeddings)
         if torch.cuda.is_available():
-            network = network.cuda()
-        print(f"{network}\n")
+            parser = parser.cuda()
+        print(f"{parser}\n")
+
+        model = Model(vocab, parser)
+
+        total_time = timedelta()
+        best_e, best_metric = 1, Metric()
+        model.optimizer = Adam(model.parser.parameters(),
+                               config.lr,
+                               (config.beta_1, config.beta_2),
+                               config.epsilon)
+        model.scheduler = ExponentialLR(model.optimizer,
+                                        config.decay ** (1 / config.steps))
+
+        for epoch in range(1, config.epochs + 1):
+            start = datetime.now()
+            # train one epoch and update the parameters
+            model.train(train_loader)
+
+            print(f"Epoch {epoch} / {config.epochs}:")
+            loss, train_metric = model.evaluate(train_loader, config.punct)
+            print(f"{'train:':6} Loss: {loss:.4f} {train_metric}")
+            loss, dev_metric = model.evaluate(dev_loader, config.punct)
+            print(f"{'dev:':6} Loss: {loss:.4f} {dev_metric}")
+            loss, test_metric = model.evaluate(test_loader, config.punct)
+            print(f"{'test:':6} Loss: {loss:.4f} {test_metric}")
+
+            t = datetime.now() - start
+            # save the model if it is the best so far
+            if dev_metric > best_metric and epoch > config.patience:
+                best_e, best_metric = epoch, dev_metric
+                model.parser.save(config.model + f".{best_e}")
+                print(f"{t}s elapsed (saved)\n")
+            else:
+                print(f"{t}s elapsed\n")
            total_time += t
+            if epoch - best_e >= config.patience:
+                break
+        model.parser = BiaffineParser.load(config.model + f".{best_e}")
+        loss, metric = model.evaluate(test_loader, config.punct)
 
-        model = Model(vocab, network)
-        model(loaders=(train_loader, dev_loader, test_loader),
-              epochs=Config.epochs,
-              patience=Config.patience,
-              lr=Config.lr,
-              betas=(Config.beta_1, Config.beta_2),
-              epsilon=Config.epsilon,
-              annealing=lambda x: Config.decay ** (x / Config.decay_steps),
-              file=args.file)
+        print(f"max score of dev is {best_metric.score:.2%} at epoch {best_e}")
+        print(f"the score of test at epoch {best_e} is {metric.score:.2%}")
+        print(f"average time of each epoch is {total_time / epoch}s")
+        print(f"{total_time}s elapsed")

parser/model.py

Lines changed: 14 additions & 57 deletions
@@ -1,97 +1,54 @@
 # -*- coding: utf-8 -*-
 
-from datetime import datetime, timedelta
-from parser.metric import AttachmentMethod
-from parser.parser import BiaffineParser
+from parser.metric import Metric
 
 import torch
 import torch.nn as nn
-import torch.optim as optim
 
 
 class Model(object):
 
-    def __init__(self, vocab, network):
+    def __init__(self, vocab, parser):
         super(Model, self).__init__()
 
         self.vocab = vocab
-        self.network = network
+        self.parser = parser
         self.criterion = nn.CrossEntropyLoss()
 
-    def __call__(self, loaders, epochs, patience,
-                 lr, betas, epsilon, annealing, file):
-        total_time = timedelta()
-        max_e, max_metric = 0, 0.0
-        train_loader, dev_loader, test_loader = loaders
-        self.optimizer = optim.Adam(params=self.network.parameters(),
-                                    lr=lr, betas=betas, eps=epsilon)
-        self.scheduler = optim.lr_scheduler.LambdaLR(optimizer=self.optimizer,
-                                                     lr_lambda=annealing)
-
-        for epoch in range(1, epochs + 1):
-            start = datetime.now()
-            # train one epoch and update the parameters
-            self.train(train_loader)
-
-            print(f"Epoch {epoch} / {epochs}:")
-            loss, train_metric = self.evaluate(train_loader)
-            print(f"{'train:':<6} Loss: {loss:.4f} {train_metric}")
-            loss, dev_metric = self.evaluate(dev_loader)
-            print(f"{'dev:':<6} Loss: {loss:.4f} {dev_metric}")
-            loss, test_metric = self.evaluate(test_loader)
-            print(f"{'test:':<6} Loss: {loss:.4f} {test_metric}")
-            t = datetime.now() - start
-            print(f"{t}s elapsed\n")
-            total_time += t
-
-            # save the model if it is the best so far
-            if dev_metric > max_metric:
-                self.network.save(file)
-                max_e, max_metric = epoch, dev_metric
-            elif epoch - max_e >= patience:
-                break
-        self.network = BiaffineParser.load(file)
-        loss, metric = self.evaluate(test_loader)
-
-        print(f"max score of dev is {max_metric.score:.2%} at epoch {max_e}")
-        print(f"the score of test at epoch {max_e} is {metric.score:.2%}")
-        print(f"mean time of each epoch is {total_time / epoch}s")
-        print(f"{total_time}s elapsed")
-
     def train(self, loader):
-        self.network.train()
+        self.parser.train()
 
         for words, tags, arcs, rels in loader:
             self.optimizer.zero_grad()
 
             mask = words.ne(self.vocab.pad_index)
             # ignore the first token of each sentence
             mask[:, 0] = 0
-            s_arc, s_rel = self.network(words, tags)
+            s_arc, s_rel = self.parser(words, tags)
             s_arc, s_rel = s_arc[mask], s_rel[mask]
             gold_arcs, gold_rels = arcs[mask], rels[mask]
 
             loss = self.get_loss(s_arc, s_rel, gold_arcs, gold_rels)
             loss.backward()
-            nn.utils.clip_grad_norm_(self.network.parameters(), 5.0)
+            nn.utils.clip_grad_norm_(self.parser.parameters(), 5.0)
             self.optimizer.step()
             self.scheduler.step()
 
     @torch.no_grad()
-    def evaluate(self, loader, include_punct=False):
-        self.network.eval()
+    def evaluate(self, loader, punct=False):
+        self.parser.eval()
 
-        loss, metric = 0, AttachmentMethod()
+        loss, metric = 0, Metric()
 
         for words, tags, arcs, rels in loader:
             mask = words.ne(self.vocab.pad_index)
             # ignore the first token of each sentence
             mask[:, 0] = 0
             # ignore all punctuation if not specified
-            if not include_punct:
+            if not punct:
                 puncts = words.new_tensor(self.vocab.puncts)
                 mask &= words.unsqueeze(-1).ne(puncts).all(-1)
-            s_arc, s_rel = self.network(words, tags)
+            s_arc, s_rel = self.parser(words, tags)
             s_arc, s_rel = s_arc[mask], s_rel[mask]
             gold_arcs, gold_rels = arcs[mask], rels[mask]
             pred_arcs, pred_rels = self.decode(s_arc, s_rel)
@@ -104,15 +61,15 @@ def evaluate(self, loader, include_punct=False):
 
     @torch.no_grad()
     def predict(self, loader):
-        self.network.eval()
+        self.parser.eval()
 
         all_arcs, all_rels = [], []
-        for words, tags, arcs, rels in loader:
+        for words, tags in loader:
             mask = words.ne(self.vocab.pad_index)
             # ignore the first token of each sentence
             mask[:, 0] = 0
             lens = mask.sum(dim=1).tolist()
-            s_arc, s_rel = self.network(words, tags)
+            s_arc, s_rel = self.parser(words, tags)
             s_arc, s_rel = s_arc[mask], s_rel[mask]
             pred_arcs, pred_rels = self.decode(s_arc, s_rel)
 
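
The punctuation filter kept in `evaluate` (only the flag is renamed from `include_punct` to `punct`) works purely on word indices: `words.unsqueeze(-1).ne(puncts).all(-1)` is True only where a token matches none of the punctuation ids, and `&=` folds that into the padding/root mask. A small self-contained illustration with made-up indices (0 as the pad index, 5 and 9 as punctuation ids; these values are hypothetical, not taken from the actual vocab):

    import torch

    # hypothetical vocabulary ids: 0 = <pad>, 5 and 9 = punctuation
    words = torch.tensor([[3, 7, 5, 2, 0],
                          [4, 9, 6, 0, 0]])
    puncts = words.new_tensor([5, 9])
    pad_index = 0

    mask = words.ne(pad_index)   # drop padding positions
    mask[:, 0] = 0               # ignore the first (root) token of each sentence
    # keep a position only if its id differs from every punctuation id
    mask &= words.unsqueeze(-1).ne(puncts).all(-1)

    print(mask)
    # tensor([[False,  True, False,  True, False],
    #         [False, False,  True, False, False]])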
