Skip to content

Commit 052de92

Browse files
committed
Fix bug of nltk.Tree construction (yzhangcs#65)
1 parent 1e3e051 commit 052de92

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

supar/utils/transform.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def tgt(self):
487487
return self.CHART,
488488

489489
@classmethod
490-
def totree(cls, tokens, root=''):
490+
def totree(cls, tokens, root='', special_tokens={'(': '-LRB-', ')': '-RRB-'}):
491491
r"""
492492
Converts a list of tokens to a :class:`nltk.tree.Tree`.
493493
Missing fields are filled with underscores.
@@ -497,6 +497,9 @@ def totree(cls, tokens, root=''):
497497
This can be either a list of words or word/pos pairs.
498498
root (str):
499499
The root label of the tree. Default: ''.
500+
special_tokens (dict):
501+
A dict for normalizing some special tokens to avoid tree construction crash.
502+
Default: {'(': '-LRB-', ')': '-RRB-'}.
500503
501504
Returns:
502505
A :class:`nltk.tree.Tree` object.
@@ -508,8 +511,15 @@ def totree(cls, tokens, root=''):
508511

509512
if isinstance(tokens[0], str):
510513
tokens = [(token, '_') for token in tokens]
511-
tree = ' '.join([f"( ({pos} {word}))" for word, pos in tokens])
512-
return nltk.Tree.fromstring(f"({root} {tree})")
514+
mapped = []
515+
for i, (word, pos) in enumerate(tokens):
516+
if word in special_tokens:
517+
tokens[i] = (special_tokens[word], pos)
518+
mapped.append((i, word))
519+
tree = nltk.Tree.fromstring(f"({root} {' '.join([f'( ({pos} {word}))' for word, pos in tokens])})")
520+
for i, word in mapped:
521+
tree[i][0][0] = word
522+
return tree
513523

514524
@classmethod
515525
def binarize(cls, tree):

0 commit comments

Comments
 (0)