
Commit c9f8b13

Transform a tree into a single chart
1 parent 3a1dc5a commit c9f8b13

File tree: 3 files changed (+25, -44 lines)


supar/models/constituency.py

Lines changed: 7 additions & 10 deletions

@@ -206,17 +206,15 @@ def forward(self, words, feats):
 
         return s_span, s_label
 
-    def loss(self, s_span, s_label, spans, labels, mask, mbr=True):
+    def loss(self, s_span, s_label, charts, mask, mbr=True):
         r"""
         Args:
             s_span (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
                 Scores of all spans
             s_label (~torch.Tensor): ``[batch_size, seq_len, seq_len, n_labels]``.
                 Scores of all labels on each span.
-            spans (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
-                The tensor of gold-standard spans. ``True`` denotes there exist a span.
-            labels (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
-                The tensor of gold-standard labels.
+            charts (~torch.LongTensor): ``[batch_size, seq_len, seq_len]``.
+                The tensor of gold-standard labels, in which positions without labels are filled with -1.
             mask (~torch.BoolTensor): ``[batch_size, seq_len, seq_len]``.
                 The mask for covering the unpadded tokens in each chart.
             mbr (bool):
@@ -228,9 +226,9 @@ def loss(self, s_span, s_label, spans, labels, mask, mbr=True):
                 original span scores of shape ``[batch_size, seq_len, seq_len]`` if ``mbr=False``, or marginals otherwise.
         """
 
-        span_mask = spans & mask
-        span_loss, span_probs = self.crf(s_span, mask, spans, mbr)
-        label_loss = self.criterion(s_label[span_mask], labels[span_mask])
+        span_mask = charts.ge(0) & mask
+        span_loss, span_probs = self.crf(s_span, mask, span_mask, mbr)
+        label_loss = self.criterion(s_label[span_mask], charts[span_mask])
         loss = span_loss + label_loss
 
         return loss, span_probs
@@ -252,5 +250,4 @@ def decode(self, s_span, s_label, mask):
 
         span_preds = cky(s_span, mask)
         label_preds = s_label.argmax(-1).tolist()
-        return [[(i, j, labels[i][j]) for i, j in spans]
-                for spans, labels in zip(span_preds, label_preds)]
+        return [[(i, j, labels[i][j]) for i, j in spans] for spans, labels in zip(span_preds, label_preds)]
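
The change above folds the two gold-standard tensors that loss used to take into a single charts tensor: a cell holds the label index of a gold constituent and -1 everywhere else, so charts.ge(0) plays the role of the old boolean spans tensor. A minimal sketch of that equivalence on made-up toy tensors (not taken from the repository):

    import torch

    # Hypothetical 1-sentence batch with seq_len = 4; -1 marks cells without a gold constituent.
    charts = torch.tensor([[[-1,  3, -1,  7],
                            [-1, -1,  5, -1],
                            [-1, -1, -1,  2],
                            [-1, -1, -1, -1]]])
    # A toy mask covering the strictly upper triangle, standing in for the real chart mask.
    mask = torch.ones_like(charts, dtype=torch.bool).triu_(1)

    # charts.ge(0) recovers what the removed `spans` tensor used to carry ...
    span_mask = charts.ge(0) & mask
    # ... and indexing with it selects what `labels[span_mask]` used to select.
    print(span_mask.nonzero().tolist())   # [[0, 0, 1], [0, 0, 3], [0, 1, 2], [0, 2, 3]]
    print(charts[span_mask].tolist())     # [3, 7, 5, 2]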

supar/parsers/crf_constituency.py

Lines changed: 4 additions & 4 deletions

@@ -133,15 +133,15 @@ def _train(self, loader):
 
         bar = progress_bar(loader)
 
-        for words, feats, trees, (spans, labels) in bar:
+        for words, feats, trees, charts in bar:
             self.optimizer.zero_grad()
 
             batch_size, seq_len = words.shape
             lens = words.ne(self.args.pad_index).sum(1) - 1
             mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
             mask = mask & mask.new_ones(seq_len-1, seq_len-1).triu_(1)
             s_span, s_label = self.model(words, feats)
-            loss, _ = self.model.loss(s_span, s_label, spans, labels, mask, self.args.mbr)
+            loss, _ = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
             loss.backward()
             nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
             self.optimizer.step()
@@ -155,13 +155,13 @@ def _evaluate(self, loader):
 
         total_loss, metric = 0, BracketMetric()
 
-        for words, feats, trees, (spans, labels) in loader:
+        for words, feats, trees, charts in loader:
             batch_size, seq_len = words.shape
             lens = words.ne(self.args.pad_index).sum(1) - 1
             mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
             mask = mask & mask.new_ones(seq_len-1, seq_len-1).triu_(1)
             s_span, s_label = self.model(words, feats)
-            loss, s_span = self.model.loss(s_span, s_label, spans, labels, mask, self.args.mbr)
+            loss, s_span = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
             chart_preds = self.model.decode(s_span, s_label, mask)
             # since the evaluation relies on terminals,
             # the tree should be first built and then factorized
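
Apart from unpacking charts instead of (spans, labels), these loops are unchanged; the chart mask they build keeps only the strictly upper triangle of each sentence's chart, up to that sentence's length. A toy sketch of that construction, assuming a batch of two sentences where seq_len is 4 and the computed lens come out as 3 and 2 (made-up values, not from the repository):

    import torch

    seq_len = 4
    lens = torch.tensor([3, 2])   # stands in for words.ne(self.args.pad_index).sum(1) - 1

    # Keep fencepost positions below each sentence's length ...
    mask = lens.new_tensor(range(seq_len - 1)) < lens.view(-1, 1, 1)
    # ... and only spans (i, j) with i < j, i.e. the strictly upper triangle.
    mask = mask & mask.new_ones(seq_len - 1, seq_len - 1).triu_(1)

    print(mask[0].tolist())   # [[False, True, True], [False, False, True], [False, False, False]]
    print(mask[1].tolist())   # [[False, True, False], [False, False, False], [False, False, False]]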

supar/utils/field.py

Lines changed: 14 additions & 30 deletions

@@ -321,26 +321,18 @@ class ChartField(Field):
     Field dealing with constituency trees.
 
     This field receives sequences of binarized trees factorized in pre-order,
-    and returns two tensors representing the bracketing trees and labels on each constituent respectively.
+    and returns charts filled with labels on each constituent.
 
     Examples:
         >>> sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'),
                         (2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')]
-        >>> spans, labels = field.transform([sequence])[0]  # this example field is built from ptb
-        >>> spans
-        tensor([[False,  True, False, False,  True,  True],
-                [False, False,  True, False,  True, False],
-                [False, False, False,  True,  True, False],
-                [False, False, False, False,  True, False],
-                [False, False, False, False, False,  True],
-                [False, False, False, False, False, False]])
-        >>> labels
-        tensor([[  0,  37,   0,   0, 107,  79],
-                [  0,   0, 120,   0, 112,   0],
-                [  0,   0,   0, 120,  86,   0],
-                [  0,   0,   0,   0,  37,   0],
-                [  0,   0,   0,   0,   0, 107],
-                [  0,   0,   0,   0,   0,   0]])
+        >>> field.transform([sequence])[0]
+        tensor([[ -1,  37,  -1,  -1, 107,  79],
+                [ -1,  -1, 120,  -1, 112,  -1],
+                [ -1,  -1,  -1, 120,  86,  -1],
+                [ -1,  -1,  -1,  -1,  37,  -1],
+                [ -1,  -1,  -1,  -1,  -1, 107],
+                [ -1,  -1,  -1,  -1,  -1,  -1]])
     """
 
     def build(self, dataset, min_freq=1):
@@ -351,20 +343,12 @@ def build(self, dataset, min_freq=1):
         self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index)
 
     def transform(self, sequences):
-        sequences = [self.preprocess(seq) for seq in sequences]
-        spans, labels = [], []
-
+        charts = []
         for sequence in sequences:
+            sequence = self.preprocess(sequence)
             seq_len = sequence[0][1] + 1
-            span_chart = torch.full((seq_len, seq_len), self.pad_index, dtype=torch.bool)
-            label_chart = torch.full((seq_len, seq_len), self.pad_index, dtype=torch.long)
+            chart = torch.full((seq_len, seq_len), -1, dtype=torch.long)
             for i, j, label in sequence:
-                span_chart[i, j] = 1
-                label_chart[i, j] = self.vocab[label]
-            spans.append(span_chart)
-            labels.append(label_chart)
-
-        return list(zip(spans, labels))
-
-    def compose(self, sequences):
-        return [pad(i).to(self.device) for i in zip(*sequences)]
+                chart[i, j] = self.vocab[label]
+            charts.append(chart)
+        return charts
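
With one chart per sentence, the compose override that zipped and padded two tensors is no longer needed (batching is presumably left to the parent Field). The sketch below replays the new transform logic standalone; the label-to-index mapping is made up to match the ids shown in the docstring example rather than taken from a vocab actually built on PTB:

    import torch

    # Made-up mapping chosen to reproduce the docstring example above.
    vocab = {'NP': 37, 'S': 79, 'S+VP': 86, 'S|<>': 107, 'VP': 112, 'VP|<>': 120}

    sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'),
                (2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')]

    seq_len = sequence[0][1] + 1                                  # the first (widest) span covers the whole sentence
    chart = torch.full((seq_len, seq_len), -1, dtype=torch.long)  # -1 means "no constituent here"
    for i, j, label in sequence:
        chart[i, j] = vocab[label]

    print(chart)         # matches the tensor in the Examples block
    print(chart.ge(0))   # downstream, this recovers the old boolean span chart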
