@@ -86,7 +86,7 @@ def tarjan(sequence):
86
86
List of head indices.
87
87
88
88
Yields:
89
- A list of indices that make up a SCC. All self-loops are ignored.
89
+ A list of indices making up an SCC. All self-loops are ignored.
90
90
91
91
Examples:
92
92
>>> next(tarjan([2, 5, 0, 3, 1])) # (1 -> 5 -> 2 -> 1) is a cycle
@@ -135,19 +135,14 @@ def connect(i, timestep):
135
135
136
136
def chuliu_edmonds (s ):
137
137
r"""
138
- ChuLiu/Edmonds algorithm for non-projective decoding.
138
+ ChuLiu/Edmonds algorithm for non-projective decoding :cite:`mcdonald-etal-2005-non`.
139
139
140
140
Some code is borrowed from `tdozat's implementation`_.
141
- Descriptions of notations and formulas can be found in
142
- `Non-projective Dependency Parsing using Spanning Tree Algorithms`_.
141
+ Descriptions of notations and formulas can be found in :cite:`mcdonald-etal-2005-non`.
143
142
144
143
Notes:
145
144
The algorithm does not guarantee to parse a single-root tree.
146
145
147
- References:
148
- - Ryan McDonald, Fernando Pereira, Kiril Ribarov and Jan Hajic. 2005.
149
- `Non-projective Dependency Parsing using Spanning Tree Algorithms`_.
150
-
151
146
Args:
152
147
s (~torch.Tensor): ``[seq_len, seq_len]``.
153
148
Scores of all dependent-head pairs.
@@ -158,8 +153,6 @@ def chuliu_edmonds(s):
158
153
159
154
.. _tdozat's implementation:
160
155
https://github.com/tdozat/Parser-v3
161
- .. _Non-projective Dependency Parsing using Spanning Tree Algorithms:
162
- https://www.aclweb.org/anthology/H05-1066/
163
156
"""
164
157
165
158
s [0 , 1 :] = float ('-inf' )
@@ -234,7 +227,7 @@ def contract(s):
234
227
235
228
def mst (scores , mask , multiroot = False ):
236
229
r"""
237
- MST algorithm for decoding non-pojective trees.
230
+ MST algorithm for decoding non-projective trees.
238
231
This is a wrapper for ChuLiu/Edmonds algorithm.
239
232
240
233
The algorithm first runs ChuLiu/Edmonds to parse a tree and then have a check of multi-roots,
@@ -248,7 +241,7 @@ def mst(scores, mask, multiroot=False):
248
241
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
249
242
The mask to avoid parsing over padding tokens.
250
243
The first column serving as pseudo words for roots should be ``False``.
251
- muliroot (bool):
244
+ multiroot (bool):
252
245
Ensures to parse a single-root tree If ``False``.
253
246
254
247
Returns:
@@ -291,20 +284,18 @@ def mst(scores, mask, multiroot=False):
291
284
return pad (preds , total_length = seq_len ).to (mask .device )
292
285
293
286
294
- def eisner (scores , mask ):
287
+ def eisner (scores , mask , multiroot = False ):
295
288
r"""
296
- First-order Eisner algorithm for projective decoding.
297
-
298
- References:
299
- - Ryan McDonald, Koby Crammer and Fernando Pereira. 2005.
300
- `Online Large-Margin Training of Dependency Parsers`_.
289
+ First-order Eisner algorithm for projective decoding :cite:`mcdonald-etal-2005-online`.
301
290
302
291
Args:
303
292
scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
304
293
Scores of all dependent-head pairs.
305
294
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
306
295
The mask to avoid parsing over padding tokens.
307
296
The first column serving as pseudo words for roots should be ``False``.
297
+ multiroot (bool):
298
+ Ensures to parse a single-root tree if ``False``.
308
299
309
300
Returns:
310
301
~torch.Tensor:
@@ -318,9 +309,6 @@ def eisner(scores, mask):
318
309
>>> mask = torch.tensor([[False, True, True, True]])
319
310
>>> eisner(scores, mask)
320
311
tensor([[0, 2, 0, 2]])
321
-
322
- .. _Online Large-Margin Training of Dependency Parsers:
323
- https://www.aclweb.org/anthology/P05-1012/
324
312
"""
325
313
326
314
lens = mask .sum (1 )
@@ -357,7 +345,8 @@ def eisner(scores, mask):
357
345
cr = stripe (s_i , n , w , (0 , 1 )) + stripe (s_c , n , w , (1 , w ), 0 )
358
346
cr_span , cr_path = cr .permute (2 , 0 , 1 ).max (- 1 )
359
347
s_c .diagonal (w ).copy_ (cr_span )
360
- s_c [0 , w ][lens .ne (w )] = float ('-inf' )
348
+ if not multiroot :
349
+ s_c [0 , w ][lens .ne (w )] = float ('-inf' )
361
350
p_c .diagonal (w ).copy_ (cr_path + starts + 1 )
362
351
363
352
def backtrack (p_i , p_c , heads , i , j , complete ):
@@ -384,23 +373,21 @@ def backtrack(p_i, p_c, heads, i, j, complete):
384
373
return pad (preds , total_length = seq_len ).to (mask .device )
385
374
386
375
387
- def eisner2o (scores , mask ):
376
+ def eisner2o (scores , mask , multiroot = False ):
388
377
r"""
389
- Second-order Eisner algorithm for projective decoding.
378
+ Second-order Eisner algorithm for projective decoding :cite:`mcdonald-pereira-2006-online`.
390
379
This is an extension of the first-order one that further incorporates sibling scores into tree scoring.
391
380
392
- References:
393
- - Ryan McDonald and Fernando Pereira. 2006.
394
- `Online Learning of Approximate Dependency Parsing Algorithms`_.
395
-
396
381
Args:
397
382
scores (~torch.Tensor, ~torch.Tensor):
398
- A tuple of two tensors representing the first-order and second-order scores repectively .
383
+ A tuple of two tensors representing the first-order and second-order scores respectively.
399
384
The first (``[batch_size, seq_len, seq_len]``) holds scores of all dependent-head pairs.
400
385
The second (``[batch_size, seq_len, seq_len, seq_len]``) holds scores of all dependent-head-sibling triples.
401
386
mask (~torch.BoolTensor): ``[batch_size, seq_len]``.
402
387
The mask to avoid parsing over padding tokens.
403
388
The first column serving as pseudo words for roots should be ``False``.
389
+ multiroot (bool):
390
+ Ensures to parse a single-root tree if ``False``.
404
391
405
392
Returns:
406
393
~torch.Tensor:
@@ -430,9 +417,6 @@ def eisner2o(scores, mask):
430
417
>>> mask = torch.tensor([[False, True, True, True]])
431
418
>>> eisner2o((s_arc, s_sib), mask)
432
419
tensor([[0, 2, 0, 2]])
433
-
434
- .. _Online Learning of Approximate Dependency Parsing Algorithms:
435
- https://www.aclweb.org/anthology/E06-1011/
436
420
"""
437
421
438
422
# the end position of each sentence in a batch
@@ -502,8 +486,8 @@ def eisner2o(scores, mask):
502
486
cr = stripe (s_i , n , w , (0 , 1 )) + stripe (s_c , n , w , (1 , w ), 0 )
503
487
cr_span , cr_path = cr .permute (2 , 0 , 1 ).max (- 1 )
504
488
s_c .diagonal (w ).copy_ (cr_span )
505
- # disable multi words to modify the root
506
- s_c [0 , w ][lens .ne (w )] = float ('-inf' )
489
+ if not multiroot :
490
+ s_c [0 , w ][lens .ne (w )] = float ('-inf' )
507
491
p_c .diagonal (w ).copy_ (cr_path + starts + 1 )
508
492
509
493
def backtrack (p_i , p_s , p_c , heads , i , j , flag ):
@@ -541,11 +525,7 @@ def backtrack(p_i, p_s, p_c, heads, i, j, flag):
541
525
542
526
def cky (scores , mask ):
543
527
r"""
544
- The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees.
545
-
546
- References:
547
- - Yu Zhang, Houquan Zhou and Zhenghua Li. 2020.
548
- `Fast and Accurate Neural CRF Constituency Parsing`_.
528
+ The implementation of `Cocke-Kasami-Younger`_ (CKY) algorithm to parse constituency trees :cite:`zhang-etal-2020-fast`.
549
529
550
530
Args:
551
531
scores (~torch.Tensor): ``[batch_size, seq_len, seq_len]``.
@@ -571,41 +551,43 @@ def cky(scores, mask):
571
551
572
552
.. _Cocke-Kasami-Younger:
573
553
https://en.wikipedia.org/wiki/CYK_algorithm
574
- .. _Fast and Accurate Neural CRF Constituency Parsing:
575
- https://www.ijcai.org/Proceedings/2020/560/
576
554
"""
577
555
578
556
lens = mask [:, 0 ].sum (- 1 )
579
- scores = scores .permute (1 , 2 , 0 )
580
- seq_len , seq_len , batch_size = scores .shape
557
+ scores = scores .permute (1 , 2 , 3 , 0 )
558
+ seq_len , seq_len , n_labels , batch_size = scores .shape
581
559
s = scores .new_zeros (seq_len , seq_len , batch_size )
582
- p = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
560
+ p_s = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
561
+ p_l = scores .new_zeros (seq_len , seq_len , batch_size ).long ()
583
562
584
563
for w in range (1 , seq_len ):
585
564
n = seq_len - w
586
- starts = p .new_tensor (range (n )).unsqueeze (0 )
565
+ starts = p_s .new_tensor (range (n )).unsqueeze (0 )
566
+ s_l , p = scores .diagonal (w ).max (0 )
567
+ p_l .diagonal (w ).copy_ (p )
587
568
588
569
if w == 1 :
589
- s .diagonal (w ).copy_ (scores . diagonal ( w ) )
570
+ s .diagonal (w ).copy_ (s_l )
590
571
continue
591
572
# [n, w, batch_size]
592
- s_span = stripe (s , n , w - 1 , (0 , 1 )) + stripe (s , n , w - 1 , (1 , w ), 0 )
573
+ s_s = stripe (s , n , w - 1 , (0 , 1 )) + stripe (s , n , w - 1 , (1 , w ), 0 )
593
574
# [batch_size, n, w]
594
- s_span = s_span .permute (2 , 0 , 1 )
575
+ s_s = s_s .permute (2 , 0 , 1 )
595
576
# [batch_size, n]
596
- s_span , p_span = s_span .max (- 1 )
597
- s .diagonal (w ).copy_ (s_span + scores . diagonal ( w ) )
598
- p .diagonal (w ).copy_ (p_span + starts + 1 )
577
+ s_s , p = s_s .max (- 1 )
578
+ s .diagonal (w ).copy_ (s_s + s_l )
579
+ p_s .diagonal (w ).copy_ (p + starts + 1 )
599
580
600
- def backtrack (p , i , j ):
581
+ def backtrack (p_s , p_l , i , j ):
601
582
if j == i + 1 :
602
- return [(i , j )]
603
- split = p [i ][j ]
604
- ltree = backtrack (p , i , split )
605
- rtree = backtrack (p , split , j )
606
- return [(i , j )] + ltree + rtree
607
-
608
- p = p .permute (2 , 0 , 1 ).tolist ()
609
- trees = [backtrack (p [i ], 0 , length ) for i , length in enumerate (lens .tolist ())]
583
+ return [(i , j , p_l [i ][j ])]
584
+ split , label = p_s [i ][j ], p_l [i ][j ]
585
+ ltree = backtrack (p_s , p_l , i , split )
586
+ rtree = backtrack (p_s , p_l , split , j )
587
+ return [(i , j , label )] + ltree + rtree
588
+
589
+ p_s = p_s .permute (2 , 0 , 1 ).tolist ()
590
+ p_l = p_l .permute (2 , 0 , 1 ).tolist ()
591
+ trees = [backtrack (p_s [i ], p_l [i ], 0 , length ) for i , length in enumerate (lens .tolist ())]
610
592
611
593
return trees
0 commit comments