Skip to content

Commit 3c1ff3b

Browse files
committed
Add tests
1 parent 573c2b7 commit 3c1ff3b

File tree

3 files changed

+93
-0
lines changed

3 files changed

+93
-0
lines changed

tests/test_alg.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from supar.utils import tarjan
4+
5+
6+
def test_tarjan():
7+
sequences = [[4, 1, 2, 0, 4, 4, 8, 6, 8],
8+
[2, 5, 0, 3, 1, 5, 8, 6, 8],
9+
[2, 5, 0, 4, 1, 5, 8, 6, 8],
10+
[2, 5, 0, 4, 1, 9, 6, 5, 7]]
11+
answers = [None, [[2, 5, 1]], [[2, 5, 1]], [[2, 5, 1], [9, 7, 6]]]
12+
for sequence, answer in zip(sequences, answers):
13+
if answer is None:
14+
assert next(tarjan(sequence), None) == answer
15+
else:
16+
assert list(tarjan(sequence)) == answer

tests/test_parse.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from supar import Parser
4+
import supar
5+
6+
7+
def test_parse():
8+
sentence = ['The', 'dog', 'chases', 'the', 'cat', '.']
9+
for name in supar.PRETRAINED:
10+
parser = Parser.load(name)
11+
parser.predict([sentence], prob=True)

tests/test_transform.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import itertools
4+
5+
import nltk
6+
from supar.utils import CoNLL, Tree
7+
8+
9+
class TestCoNLL:
10+
11+
def istree_naive(self, sequence, proj=False, multiroot=True):
12+
if proj and not CoNLL.isprojective(sequence):
13+
return False
14+
roots = [i for i, head in enumerate(sequence, 1) if head == 0]
15+
if len(roots) == 0:
16+
return False
17+
if len(roots) > 1 and not multiroot:
18+
return False
19+
sequence = [-1] + sequence
20+
21+
def track(sequence, visited, i):
22+
if visited[i]:
23+
return False
24+
visited[i] = True
25+
for j, head in enumerate(sequence[1:], 1):
26+
if head == i:
27+
track(sequence, visited, j)
28+
return True
29+
visited = [False]*len(sequence)
30+
for root in roots:
31+
if not track(sequence, visited, root):
32+
return False
33+
if any([not i for i in visited[1:]]):
34+
return False
35+
return True
36+
37+
def test_isprojective(self):
38+
assert CoNLL.isprojective([2, 4, 2, 0, 5])
39+
assert CoNLL.isprojective([3, -1, 0, -1, 3])
40+
assert not CoNLL.isprojective([2, 4, 0, 3, 4])
41+
assert not CoNLL.isprojective([4, -1, 0, -1, 4])
42+
assert not CoNLL.isprojective([2, -1, -1, 1, 0])
43+
assert not CoNLL.isprojective([0, 5, -1, -1, 4])
44+
45+
def test_istree(self):
46+
permutations = [list(sequence[:5]) for sequence in itertools.permutations(range(6))]
47+
for sequence in permutations:
48+
assert CoNLL.istree(sequence, False, False) == self.istree_naive(sequence, False, False), f"{sequence}"
49+
assert CoNLL.istree(sequence, False, True) == self.istree_naive(sequence, False, True), f"{sequence}"
50+
assert CoNLL.istree(sequence, True, False) == self.istree_naive(sequence, True, False), f"{sequence}"
51+
assert CoNLL.istree(sequence, True, True) == self.istree_naive(sequence, True, True), f"{sequence}"
52+
53+
54+
class TestTree:
55+
56+
def test_tree(self):
57+
tree = nltk.Tree.fromstring("""
58+
(TOP
59+
(S
60+
(NP (DT This) (NN time))
61+
(, ,)
62+
(NP (DT the) (NNS firms))
63+
(VP (VBD were) (ADJP (JJ ready)))
64+
(. .)))
65+
""")
66+
assert tree == Tree.build(tree, Tree.factorize(Tree.binarize(tree)[0]))

0 commit comments

Comments
 (0)