@@ -94,7 +94,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
94
94
95
95
return super ().evaluate (** Config ().update (locals ()))
96
96
97
- def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False ,
97
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , cache = False ,
98
98
tree = True , proj = False , verbose = True , ** kwargs ):
99
99
r"""
100
100
Args:
@@ -116,6 +116,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
116
116
The number of tokens in each batch. Default: 5000.
117
117
prob (bool):
118
118
If ``True``, outputs the probabilities. Default: ``False``.
119
+ cache (bool):
120
+ If ``True``, caches the data first; recommended when parsing huge files (e.g., > 1M sentences). Default: ``False``.
119
121
tree (bool):
120
122
If ``True``, ensures that well-formed trees are output. Default: ``True``.
121
123
proj (bool):
@@ -126,7 +128,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
126
128
A dict holding unconsumed arguments for updating prediction configs.
127
129
128
130
Returns:
129
- A :class:`~supar.utils.Dataset` object that stores the predicted results .
131
+ A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None`` .
130
132
"""
131
133
132
134
return super ().predict (** Config ().update (locals ()))
@@ -233,6 +235,7 @@ def _predict(self, loader):
233
235
batch .rels = [self .REL .vocab [i .tolist ()] for i in rel_preds [mask ].split (lens )]
234
236
if self .args .prob :
235
237
batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , s_arc .softmax (- 1 ).unbind ())]
238
+ yield from batch .sentences
236
239
237
240
@classmethod
238
241
def build (cls , path , min_freq = 2 , fix_len = 20 , ** kwargs ):
@@ -408,7 +411,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, punct=False,
408
411
409
412
return super ().evaluate (** Config ().update (locals ()))
410
413
411
- def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False ,
414
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , cache = False ,
412
415
mbr = True , tree = True , proj = True , verbose = True , ** kwargs ):
413
416
r"""
414
417
Args:
@@ -430,6 +433,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
430
433
The number of tokens in each batch. Default: 5000.
431
434
prob (bool):
432
435
If ``True``, outputs the probabilities. Default: ``False``.
436
+ cache (bool):
437
+ If ``True``, caches the data first; recommended when parsing huge files (e.g., > 1M sentences). Default: ``False``.
433
438
mbr (bool):
434
439
If ``True``, performs MBR decoding. Default: ``True``.
435
440
tree (bool):
@@ -442,7 +447,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
442
447
A dict holding unconsumed arguments for updating prediction configs.
443
448
444
449
Returns:
445
- A :class:`~supar.utils.Dataset` object that stores the predicted results .
450
+ A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None`` .
446
451
"""
447
452
448
453
return super ().predict (** Config ().update (locals ()))
@@ -553,6 +558,7 @@ def _predict(self, loader):
553
558
if self .args .prob :
554
559
arc_probs = s_arc if self .args .mbr else s_arc .softmax (- 1 )
555
560
batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , arc_probs .unbind ())]
561
+ yield from batch .sentences
556
562
557
563
558
564
class CRF2oDependencyParser (BiaffineDependencyParser ):
@@ -631,7 +637,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, punct=False,
631
637
632
638
return super ().evaluate (** Config ().update (locals ()))
633
639
634
- def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False ,
640
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , cache = False ,
635
641
mbr = True , tree = True , proj = True , verbose = True , ** kwargs ):
636
642
r"""
637
643
Args:
@@ -653,6 +659,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
653
659
The number of tokens in each batch. Default: 5000.
654
660
prob (bool):
655
661
If ``True``, outputs the probabilities. Default: ``False``.
662
+ cache (bool):
663
+ If ``True``, caches the data first; recommended when parsing huge files (e.g., > 1M sentences). Default: ``False``.
656
664
mbr (bool):
657
665
If ``True``, performs MBR decoding. Default: ``True``.
658
666
tree (bool):
@@ -665,7 +673,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
665
673
A dict holding unconsumed arguments for updating prediction configs.
666
674
667
675
Returns:
668
- A :class:`~supar.utils.Dataset` object that stores the predicted results .
676
+ A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None`` .
669
677
"""
670
678
671
679
return super ().predict (** Config ().update (locals ()))
@@ -775,6 +783,7 @@ def _predict(self, loader):
775
783
if self .args .prob :
776
784
arc_probs = s_arc if self .args .mbr else s_arc .softmax (- 1 )
777
785
batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , arc_probs .unbind ())]
786
+ yield from batch .sentences
778
787
779
788
@classmethod
780
789
def build (cls , path , min_freq = 2 , fix_len = 20 , ** kwargs ):
@@ -945,7 +954,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, punct=False,
945
954
946
955
return super ().evaluate (** Config ().update (locals ()))
947
956
948
- def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False ,
957
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , cache = False ,
949
958
tree = True , proj = True , verbose = True , ** kwargs ):
950
959
r"""
951
960
Args:
@@ -967,6 +976,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
967
976
The number of tokens in each batch. Default: 5000.
968
977
prob (bool):
969
978
If ``True``, outputs the probabilities. Default: ``False``.
979
+ cache (bool):
980
+ If ``True``, caches the data first; recommended when parsing huge files (e.g., > 1M sentences). Default: ``False``.
970
981
tree (bool):
971
982
If ``True``, ensures that well-formed trees are output. Default: ``True``.
972
983
proj (bool):
@@ -977,7 +988,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
977
988
A dict holding unconsumed arguments for updating prediction configs.
978
989
979
990
Returns:
980
- A :class:`~supar.utils.Dataset` object that stores the predicted results .
991
+ A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None`` .
981
992
"""
982
993
983
994
return super ().predict (** Config ().update (locals ()))
@@ -1085,3 +1096,4 @@ def _predict(self, loader):
1085
1096
batch .rels = [self .REL .vocab [i .tolist ()] for i in rel_preds [mask ].split (lens )]
1086
1097
if self .args .prob :
1087
1098
batch .probs = [prob [1 :i + 1 , :i + 1 ].cpu () for i , prob in zip (lens , s_arc .unbind ())]
1099
+ yield from batch .sentences
0 commit comments