Skip to content

Commit f410ffd

Browse files
committed
Update docs
1 parent 6f06f95 commit f410ffd

File tree

2 files changed

+49
-65
lines changed

2 files changed

+49
-65
lines changed

supar/utils/alg.py

Lines changed: 42 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def tarjan(sequence):
8686
List of head indices.
8787
8888
Yields:
89-
A list of indices that make up a SCC. All self-loops are ignored.
89+
A list of indices making up a SCC. All self-loops are ignored.
9090
9191
Examples:
9292
>>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle
@@ -135,19 +135,14 @@ def connect(i, timestep):
135135

136136
def chuliu_edmonds(s):
137137
r"""
138-
ChuLiu/Edmonds algorithm for non-projective decoding.
138+
ChuLiu/Edmonds algorithm for non-projective decoding :cite:`mcdonald-etal-2005-non`.
139139
140140
Some code is borrowed from `tdozat's implementation`_.
141-
Descriptions of notations and formulas can be found in
142-
`Non-projective Dependency Parsing using Spanning Tree Algorithms`_.
141+
Descriptions of notations and formulas can be found in :cite:`mcdonald-etal-2005-non`.
143142
144143
Notes:
145144
The algorithm does not guarantee to parse a single-root tree.
146145
147-
References:
148-
- Ryan McDonald, Fernando Pereira, Kiril Ribarov and Jan Hajic. 2005.
149-
`Non-projective Dependency Parsing using Spanning Tree Algorithms`_.
150-
151146
Args:
152147
s (~torch.Tensor): ``[seq_len, seq_len]``.
153148
Scores of all dependent-head pairs.
@@ -158,8 +153,6 @@ def chuliu_edmonds(s):
158153
159154
.. _tdozat's implementation:
160155
https://github.com/tdozat/Parser-v3
161-
.. _Non-projective Dependency Parsing using Spanning Tree Algorithms:
162-
https://www.aclweb.org/anthology/H05-1066/
163156
"""
164157

165158
s[0, 1:] = float('-inf')
@@ -234,7 +227,7 @@ def contract(s):
234227

235228
def mst(scores, mask, multiroot=False):
236229
r"""
237-
MST algorithm for decoding non-pojective trees.
230+
MST algorithm for decoding non-projective trees.
238231
This is a wrapper for ChuLiu/Edmonds algorithm.
239232
240233
The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots,
@@ -248,7 +241,7 @@ def mst(scores, mask, multiroot=False):
248241
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
249242
The mask to avoid parsing over padding tokens.
250243
The first column serving as pseudo words for roots should be ``False``.
251-
muliroot (bool):
244+
multiroot (bool):
252245
Ensures to parse a single-root tree If ``False``.
253246
254247
Returns:
@@ -291,20 +284,18 @@ def mst(scores, mask, multiroot=False):
291284
return pad(preds, total_length=seq_len).to(mask.device)
292285

293286

294-
def eisner(scores, mask):
287+
def eisner(scores, mask, multiroot=False):
295288
r"""
296-
First-order Eisner algorithm for projective decoding.
297-
298-
References:
299-
- Ryan McDonald, Koby Crammer and Fernando Pereira. 2005.
300-
`Online Large-Margin Training of Dependency Parsers`_.
289+
First-order Eisner algorithm for projective decoding :cite:`mcdonald-etal-2005-online`.
301290
302291
Args:
303292
scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
304293
Scores of all dependent-head pairs.
305294
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
306295
The mask to avoid parsing over padding tokens.
307296
The first column serving as pseudo words for roots should be ``False``.
297+
multiroot (bool):
298+
Ensures to parse a single-root tree If ``False``.
308299
309300
Returns:
310301
~torch.Tensor:
@@ -318,9 +309,6 @@ def eisner(scores, mask):
318309
>>> mask = torch.tensor([[False, True, True, True]])
319310
>>> eisner(scores, mask)
320311
tensor([[0, 2, 0, 2]])
321-
322-
.. _Online Large-Margin Training of Dependency Parsers:
323-
https://www.aclweb.org/anthology/P05-1012/
324312
"""
325313

326314
lens = mask.sum(1)
@@ -357,7 +345,8 @@ def eisner(scores, mask):
357345
cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
358346
cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
359347
s_c.diagonal(w).copy_(cr_span)
360-
s_c[0, w][lens.ne(w)] = float('-inf')
348+
if not multiroot:
349+
s_c[0, w][lens.ne(w)] = float('-inf')
361350
p_c.diagonal(w).copy_(cr_path + starts + 1)
362351

363352
def backtrack(p_i, p_c, heads, i, j, complete):
@@ -384,23 +373,21 @@ def backtrack(p_i, p_c, heads, i, j, complete):
384373
return pad(preds, total_length=seq_len).to(mask.device)
385374

386375

387-
def eisner2o(scores, mask):
376+
def eisner2o(scores, mask, multiroot=False):
388377
r"""
389-
Second-order Eisner algorithm for projective decoding.
378+
Second-order Eisner algorithm for projective decoding :cite:`mcdonald-pereira-2006-online`.
390379
This is an extension of the first-order one that further incorporates sibling scores into tree scoring.
391380
392-
References:
393-
- Ryan McDonald and Fernando Pereira. 2006.
394-
`Online Learning of Approximate Dependency Parsing Algorithms`_.
395-
396381
Args:
397382
scores (~torch.Tensor, ~torch.Tensor):
398-
A tuple of two tensors representing the first-order and second-order scores repectively.
383+
A tuple of two tensors representing the first-order and second-order scores respectively.
399384
The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs.
400385
The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples.
401386
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
402387
The mask to avoid parsing over padding tokens.
403388
The first column serving as pseudo words for roots should be ``False``.
389+
multiroot (bool):
390+
Ensures to parse a single-root tree If ``False``.
404391
405392
Returns:
406393
~torch.Tensor:
@@ -430,9 +417,6 @@ def eisner2o(scores, mask):
430417
>>> mask = torch.tensor([[False, True, True, True]])
431418
>>> eisner2o((s_arc, s_sib), mask)
432419
tensor([[0, 2, 0, 2]])
433-
434-
.. _Online Learning of Approximate Dependency Parsing Algorithms:
435-
https://www.aclweb.org/anthology/E06-1011/
436420
"""
437421

438422
# the end position of each sentence in a batch
@@ -502,8 +486,8 @@ def eisner2o(scores, mask):
502486
cr = stripe(s_i, n, w, (0, 1)) + stripe(s_c, n, w, (1, w), 0)
503487
cr_span, cr_path = cr.permute(2, 0, 1).max(-1)
504488
s_c.diagonal(w).copy_(cr_span)
505-
# disable multi words to modify the root
506-
s_c[0, w][lens.ne(w)] = float('-inf')
489+
if not multiroot:
490+
s_c[0, w][lens.ne(w)] = float('-inf')
507491
p_c.diagonal(w).copy_(cr_path + starts + 1)
508492

509493
def backtrack(p_i, p_s, p_c, heads, i, j, flag):
@@ -541,11 +525,7 @@ def backtrack(p_i, p_s, p_c, heads, i, j, flag):
541525

542526
def cky(scores, mask):
543527
r"""
544-
The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees.
545-
546-
References:
547-
- Yu Zhang, Houquan Zhou and Zhenghua Li. 2020.
548-
`Fast and Accurate Neural CRF Constituency Parsing`_.
528+
The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees :cite:`zhang-etal-2020-fast`.
549529
550530
Args:
551531
scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
@@ -571,41 +551,43 @@ def cky(scores, mask):
571551
572552
.. _Cocke-Kasami-Younger:
573553
https://en.wikipedia.org/wiki/CYK_algorithm
574-
.. _Fast and Accurate Neural CRF Constituency Parsing:
575-
https://www.ijcai.org/Proceedings/2020/560/
576554
"""
577555

578556
lens = mask[:, 0].sum(-1)
579-
scores = scores.permute(1, 2, 0)
580-
seq_len, seq_len, batch_size = scores.shape
557+
scores = scores.permute(1, 2, 3, 0)
558+
seq_len, seq_len, n_labels, batch_size = scores.shape
581559
s = scores.new_zeros(seq_len, seq_len, batch_size)
582-
p = scores.new_zeros(seq_len, seq_len, batch_size).long()
560+
p_s = scores.new_zeros(seq_len, seq_len, batch_size).long()
561+
p_l = scores.new_zeros(seq_len, seq_len, batch_size).long()
583562

584563
for w in range(1, seq_len):
585564
n = seq_len - w
586-
starts = p.new_tensor(range(n)).unsqueeze(0)
565+
starts = p_s.new_tensor(range(n)).unsqueeze(0)
566+
s_l, p = scores.diagonal(w).max(0)
567+
p_l.diagonal(w).copy_(p)
587568

588569
if w == 1:
589-
s.diagonal(w).copy_(scores.diagonal(w))
570+
s.diagonal(w).copy_(s_l)
590571
continue
591572
# [n, w, batch_size]
592-
s_span = stripe(s, n, w-1, (0, 1)) + stripe(s, n, w-1, (1, w), 0)
573+
s_s = stripe(s, n, w-1, (0, 1)) + stripe(s, n, w-1, (1, w), 0)
593574
# [batch_size, n, w]
594-
s_span = s_span.permute(2, 0, 1)
575+
s_s = s_s.permute(2, 0, 1)
595576
# [batch_size, n]
596-
s_span, p_span = s_span.max(-1)
597-
s.diagonal(w).copy_(s_span + scores.diagonal(w))
598-
p.diagonal(w).copy_(p_span + starts + 1)
577+
s_s, p = s_s.max(-1)
578+
s.diagonal(w).copy_(s_s + s_l)
579+
p_s.diagonal(w).copy_(p + starts + 1)
599580

600-
def backtrack(p, i, j):
581+
def backtrack(p_s, p_l, i, j):
601582
if j == i + 1:
602-
return [(i, j)]
603-
split = p[i][j]
604-
ltree = backtrack(p, i, split)
605-
rtree = backtrack(p, split, j)
606-
return [(i, j)] + ltree + rtree
607-
608-
p = p.permute(2, 0, 1).tolist()
609-
trees = [backtrack(p[i], 0, length) for i, length in enumerate(lens.tolist())]
583+
return [(i, j, p_l[i][j])]
584+
split, label = p_s[i][j], p_l[i][j]
585+
ltree = backtrack(p_s, p_l, i, split)
586+
rtree = backtrack(p_s, p_l, split, j)
587+
return [(i, j, label)] + ltree + rtree
588+
589+
p_s = p_s.permute(2, 0, 1).tolist()
590+
p_l = p_l.permute(2, 0, 1).tolist()
591+
trees = [backtrack(p_s[i], p_l[i], 0, length) for i, length in enumerate(lens.tolist())]
610592

611593
return trees

supar/utils/field.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class Field(RawField):
4242
r"""
4343
Defines a datatype together with instructions for converting to :class:`~torch.Tensor`.
4444
:class:`Field` models common text processing datatypes that can be represented by tensors.
45-
It holds a :class:`Vocab` object that defines the set of possible values
45+
It holds a :class:`~supar.utils.vocab.Vocab` object that defines the set of possible values
4646
for elements of the field and their corresponding numerical representations.
4747
The :class:`Field` object also holds other parameters relating to how a datatype
4848
should be numericalized, such as a tokenization method.
@@ -62,7 +62,8 @@ class Field(RawField):
6262
lower (bool):
6363
Whether to lowercase the text in this field. Default: ``False``.
6464
use_vocab (bool):
65-
Whether to use a :class:`Vocab` object. If ``False``, the data in this field should already be numerical.
65+
Whether to use a :class:`~supar.utils.vocab.Vocab` object.
66+
If ``False``, the data in this field should already be numerical.
6667
Default: ``True``.
6768
tokenize (function):
6869
The function used to tokenize strings using this field into sequential examples. Default: ``None``.
@@ -177,12 +178,13 @@ def preprocess(self, sequence):
177178

178179
def build(self, dataset, min_freq=1, embed=None):
179180
r"""
180-
Constructs a :class:`Vocab` object for this field from the dataset.
181+
Constructs a :class:`~supar.utils.vocab.Vocab` object for this field from the dataset.
181182
If the vocabulary has already existed, this function will have no effect.
182183
183184
Args:
184185
dataset (Dataset):
185-
A :class:`Dataset` object. One of the attributes should be named after the name of this field.
186+
A :class:`~supar.utils.data.Dataset` object.
187+
One of the attributes should be named after the name of this field.
186188
min_freq (int):
187189
The minimum frequency needed to include a token in the vocabulary. Default: 1.
188190
embed (Embedding):
@@ -338,7 +340,7 @@ class ChartField(Field):
338340
Examples:
339341
>>> chart = [[ None, 'NP', None, None, 'S|<>', 'S'],
340342
[ None, None, 'VP|<>', None, 'VP', None],
341-
[ None, None, None, 'VP|<>', 'S+VP', None],
343+
[ None, None, None, 'VP|<>', 'S::VP', None],
342344
[ None, None, None, None, 'NP', None],
343345
[ None, None, None, None, None, 'S|<>'],
344346
[ None, None, None, None, None, None]]

0 commit comments

Comments
 (0)