Skip to content

Commit b34765d

Browse files
committed
Take seqs of charts as input for ChartField
1 parent 6d4252f commit b34765d

File tree

2 files changed

+20
-25
lines changed

2 files changed

+20
-25
lines changed

supar/utils/field.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,16 @@ def transform(self, sequences):
334334

335335
class ChartField(Field):
336336
r"""
337-
Field dealing with constituency trees.
338-
339-
This field receives sequences of binarized trees factorized in pre-order,
340-
and returns charts filled with labels on each constituent.
337+
Field dealing with chart inputs.
341338
342339
Examples:
343-
>>> sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'),
344-
(2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')]
345-
>>> field.transform([sequence])[0]
340+
>>> chart = [[ None, 'NP', None, None, 'S|<>', 'S'],
341+
[ None, None, 'VP|<>', None, 'VP', None],
342+
[ None, None, None, 'VP|<>', 'S+VP', None],
343+
[ None, None, None, None, 'NP', None],
344+
[ None, None, None, None, None, 'S|<>'],
345+
[ None, None, None, None, None, None]]
346+
>>> field.transform([chart])[0]
346347
tensor([[ -1, 37, -1, -1, 107, 79],
347348
[ -1, -1, 120, -1, 112, -1],
348349
[ -1, -1, -1, 120, 86, -1],
@@ -352,19 +353,14 @@ class ChartField(Field):
352353
"""
353354

354355
def build(self, dataset, min_freq=1):
355-
counter = Counter(label
356-
for seq in getattr(dataset, self.name)
357-
for i, j, label in self.preprocess(seq))
356+
counter = Counter(i
357+
for chart in getattr(dataset, self.name)
358+
for row in self.preprocess(chart)
359+
for i in row if i is not None)
358360

359361
self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index)
360362

361-
def transform(self, sequences):
362-
charts = []
363-
for sequence in sequences:
364-
sequence = self.preprocess(sequence)
365-
seq_len = sequence[0][1] + 1
366-
chart = torch.full((seq_len, seq_len), -1, dtype=torch.long)
367-
for i, j, label in sequence:
368-
chart[i, j] = self.vocab[label]
369-
charts.append(chart)
363+
def transform(self, charts):
364+
charts = [self.preprocess(chart) for chart in charts]
365+
charts = [torch.tensor([[self.vocab[i] if i is not None else -1 for i in row] for row in chart]) for chart in charts]
370366
return charts

supar/utils/transform.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -673,12 +673,11 @@ class TreeSentence(Sentence):
673673
def __init__(self, transform, tree):
674674
super().__init__(transform)
675675

676-
# the values contain words, pos tags, raw trees, and spans
677-
# the tree is first left-binarized before factorized
678-
# spans are the factorization of tree traversed in pre-order
679-
self.values = [*zip(*tree.pos()),
680-
tree,
681-
Tree.factorize(Tree.binarize(tree)[0])]
676+
words, tags = zip(*tree.pos())
677+
chart = [[None]*(len(words)+1) for _ in range(len(words)+1)]
678+
for i, j, label in Tree.factorize(Tree.binarize(tree)[0]):
679+
chart[i][j] = label
680+
self.values = [words, tags, tree, chart]
682681

683682
def __repr__(self):
684683
return self.values[-2].pformat(1000000)

0 commit comments

Comments
 (0)