Skip to content

Commit ba18616

Browse files
committed
Add support for making predictions on huge files
1 parent 8decdb1 commit ba18616

File tree

5 files changed

+70
-31
lines changed

5 files changed

+70
-31
lines changed

supar/parsers/const.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,
9999

100100
return super().evaluate(**Config().update(locals()))
101101

102-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, mbr=True,
102+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False, mbr=True,
103103
verbose=True, **kwargs):
104104
r"""
105105
Args:
@@ -121,6 +121,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
121121
The number of tokens in each batch. Default: 5000.
122122
prob (bool):
123123
If ``True``, outputs the probabilities. Default: ``False``.
124+
cache (bool):
125+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
124126
mbr (bool):
125127
If ``True``, performs MBR decoding. Default: ``True``.
126128
verbose (bool):
@@ -129,7 +131,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
129131
A dict holding unconsumed arguments for updating prediction configs.
130132
131133
Returns:
132-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
134+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
133135
"""
134136

135137
return super().predict(**Config().update(locals()))
@@ -227,6 +229,7 @@ def _predict(self, loader):
227229
for tree, chart in zip(trees, chart_preds)]
228230
if self.args.prob:
229231
batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)]
232+
yield from batch.sentences
230233

231234
@classmethod
232235
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
@@ -398,7 +401,8 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
398401

399402
return super().evaluate(**Config().update(locals()))
400403

401-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, verbose=True, **kwargs):
404+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
405+
verbose=True, **kwargs):
402406
r"""
403407
Args:
404408
data (str or Iterable):
@@ -419,6 +423,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
419423
The number of tokens in each batch. Default: 5000.
420424
prob (bool):
421425
If ``True``, outputs the probabilities. Default: ``False``.
426+
cache (bool):
427+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
422428
mbr (bool):
423429
If ``True``, performs MBR decoding. Default: ``True``.
424430
verbose (bool):
@@ -427,7 +433,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
427433
A dict holding unconsumed arguments for updating prediction configs.
428434
429435
Returns:
430-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
436+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
431437
"""
432438

433439
return super().predict(**Config().update(locals()))
@@ -525,3 +531,4 @@ def _predict(self, loader):
525531
for tree, chart in zip(trees, chart_preds)]
526532
if self.args.prob:
527533
batch.probs = [prob[:i-1, 1:i].cpu() for i, prob in zip(lens, s_span)]
534+
yield from batch.sentences

supar/parsers/dep.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
9494

9595
return super().evaluate(**Config().update(locals()))
9696

97-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False,
97+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
9898
tree=True, proj=False, verbose=True, **kwargs):
9999
r"""
100100
Args:
@@ -116,6 +116,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
116116
The number of tokens in each batch. Default: 5000.
117117
prob (bool):
118118
If ``True``, outputs the probabilities. Default: ``False``.
119+
cache (bool):
120+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
119121
tree (bool):
120122
If ``True``, ensures to output well-formed trees. Default: ``False``.
121123
proj (bool):
@@ -126,7 +128,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
126128
A dict holding unconsumed arguments for updating prediction configs.
127129
128130
Returns:
129-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
131+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
130132
"""
131133

132134
return super().predict(**Config().update(locals()))
@@ -233,6 +235,7 @@ def _predict(self, loader):
233235
batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
234236
if self.args.prob:
235237
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.softmax(-1).unbind())]
238+
yield from batch.sentences
236239

237240
@classmethod
238241
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
@@ -408,7 +411,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, punct=False,
408411

409412
return super().evaluate(**Config().update(locals()))
410413

411-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False,
414+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
412415
mbr=True, tree=True, proj=True, verbose=True, **kwargs):
413416
r"""
414417
Args:
@@ -430,6 +433,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
430433
The number of tokens in each batch. Default: 5000.
431434
prob (bool):
432435
If ``True``, outputs the probabilities. Default: ``False``.
436+
cache (bool):
437+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
433438
mbr (bool):
434439
If ``True``, performs MBR decoding. Default: ``True``.
435440
tree (bool):
@@ -442,7 +447,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
442447
A dict holding unconsumed arguments for updating prediction configs.
443448
444449
Returns:
445-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
450+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
446451
"""
447452

448453
return super().predict(**Config().update(locals()))
@@ -553,6 +558,7 @@ def _predict(self, loader):
553558
if self.args.prob:
554559
arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
555560
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())]
561+
yield from batch.sentences
556562

557563

558564
class CRF2oDependencyParser(BiaffineDependencyParser):
@@ -631,7 +637,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, punct=False,
631637

632638
return super().evaluate(**Config().update(locals()))
633639

634-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False,
640+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
635641
mbr=True, tree=True, proj=True, verbose=True, **kwargs):
636642
r"""
637643
Args:
@@ -653,6 +659,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
653659
The number of tokens in each batch. Default: 5000.
654660
prob (bool):
655661
If ``True``, outputs the probabilities. Default: ``False``.
662+
cache (bool):
663+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
656664
mbr (bool):
657665
If ``True``, performs MBR decoding. Default: ``True``.
658666
tree (bool):
@@ -665,7 +673,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
665673
A dict holding unconsumed arguments for updating prediction configs.
666674
667675
Returns:
668-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
676+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
669677
"""
670678

671679
return super().predict(**Config().update(locals()))
@@ -775,6 +783,7 @@ def _predict(self, loader):
775783
if self.args.prob:
776784
arc_probs = s_arc if self.args.mbr else s_arc.softmax(-1)
777785
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, arc_probs.unbind())]
786+
yield from batch.sentences
778787

779788
@classmethod
780789
def build(cls, path, min_freq=2, fix_len=20, **kwargs):
@@ -945,7 +954,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, punct=False,
945954

946955
return super().evaluate(**Config().update(locals()))
947956

948-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False,
957+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
949958
tree=True, proj=True, verbose=True, **kwargs):
950959
r"""
951960
Args:
@@ -967,6 +976,8 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
967976
The number of tokens in each batch. Default: 5000.
968977
prob (bool):
969978
If ``True``, outputs the probabilities. Default: ``False``.
979+
cache (bool):
980+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
970981
tree (bool):
971982
If ``True``, ensures to output well-formed trees. Default: ``False``.
972983
proj (bool):
@@ -977,7 +988,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
977988
A dict holding unconsumed arguments for updating prediction configs.
978989
979990
Returns:
980-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
991+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
981992
"""
982993

983994
return super().predict(**Config().update(locals()))
@@ -1085,3 +1096,4 @@ def _predict(self, loader):
10851096
batch.rels = [self.REL.vocab[i.tolist()] for i in rel_preds[mask].split(lens)]
10861097
if self.args.prob:
10871098
batch.probs = [prob[1:i+1, :i+1].cpu() for i, prob in zip(lens, s_arc.unbind())]
1099+
yield from batch.sentences

supar/parsers/parser.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22

33
import os
44
from datetime import datetime, timedelta
5+
import shutil
56

67
import dill
78
import supar
89
import torch
910
import torch.distributed as dist
1011
from supar.utils import Config, Dataset
12+
import tempfile
1113
from supar.utils.field import Field
1214
from supar.utils.fn import download, get_rng_state, set_rng_state
13-
from supar.utils.logging import init_logger, logger
15+
from supar.utils.logging import init_logger, logger, progress_bar
1416
from supar.utils.metric import Metric
1517
from supar.utils.parallel import DistributedDataParallel as DDP
1618
from supar.utils.parallel import is_master
@@ -128,7 +130,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, **kwargs):
128130

129131
return loss, metric
130132

131-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, **kwargs):
133+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False, **kwargs):
132134
args = self.args.update(locals())
133135
init_logger(logger, verbose=args.verbose)
134136

@@ -143,15 +145,30 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
143145

144146
logger.info("Making predictions on the dataset")
145147
start = datetime.now()
146-
self._predict(dataset.loader)
147-
elapsed = datetime.now() - start
148-
149-
if pred is not None and is_master():
150-
logger.info(f"Saving predicted results to {pred}")
151-
self.transform.save(pred, dataset)
148+
with tempfile.TemporaryDirectory() as t:
149+
# we have clustered the sentences by length here to speed up prediction,
150+
# so the order of the yielded sentences can't be guaranteed
151+
for s in self._predict(dataset.loader):
152+
if args.cache:
153+
with open(os.path.join(t, f"{s.index}"), 'w') as f:
154+
f.write(str(s) + '\n')
155+
elapsed = datetime.now() - start
156+
157+
if pred is not None and is_master():
158+
logger.info(f"Saving predicted results to {pred}")
159+
with open(pred, 'w') as f:
160+
# merge all predictions into one single file
161+
if args.cache:
162+
for s in progress_bar(sorted(os.listdir(t), key=lambda x: int(x))):
163+
with open(os.path.join(t, s)) as s:
164+
shutil.copyfileobj(s, f)
165+
else:
166+
for s in progress_bar(dataset):
167+
f.write(str(s) + '\n')
152168
logger.info(f"{elapsed}s elapsed, {len(dataset) / elapsed.total_seconds():.2f} Sents/s")
153169

154-
return dataset
170+
if not cache:
171+
return dataset
155172

156173
def _train(self, loader):
157174
raise NotImplementedError

supar/parsers/sdp.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,8 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, verbose=True, **
7575

7676
return super().evaluate(**Config().update(locals()))
7777

78-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, verbose=True, **kwargs):
78+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
79+
verbose=True, **kwargs):
7980
r"""
8081
Args:
8182
data (str or Iterable):
@@ -96,13 +97,15 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
9697
The number of tokens in each batch. Default: 5000.
9798
prob (bool):
9899
If ``True``, outputs the probabilities. Default: ``False``.
100+
cache (bool):
101+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
99102
verbose (bool):
100103
If ``True``, increases the output verbosity. Default: ``True``.
101104
kwargs (dict):
102105
A dict holding unconsumed arguments for updating prediction configs.
103106
104107
Returns:
105-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
108+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
106109
"""
107110

108111
return super().predict(**Config().update(locals()))
@@ -201,6 +204,7 @@ def _predict(self, loader):
201204
for i, chart in zip(lens, label_preds)]
202205
if self.args.prob:
203206
batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.softmax(-1).unbind())]
207+
yield from batch.sentences
204208

205209
@classmethod
206210
def build(cls, path, min_freq=7, fix_len=20, **kwargs):
@@ -360,7 +364,8 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, verbose=True, **
360364

361365
return super().evaluate(**Config().update(locals()))
362366

363-
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, verbose=True, **kwargs):
367+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
368+
verbose=True, **kwargs):
364369
r"""
365370
Args:
366371
data (str or Iterable):
@@ -381,13 +386,15 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
381386
The number of tokens in each batch. Default: 5000.
382387
prob (bool):
383388
If ``True``, outputs the probabilities. Default: ``False``.
389+
cache (bool):
390+
If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
384391
verbose (bool):
385392
If ``True``, increases the output verbosity. Default: ``True``.
386393
kwargs (dict):
387394
A dict holding unconsumed arguments for updating prediction configs.
388395
389396
Returns:
390-
A :class:`~supar.utils.Dataset` object that stores the predicted results.
397+
A :class:`~supar.utils.Dataset` object containing all predictions if ``cache=False``, otherwise ``None``.
391398
"""
392399

393400
return super().predict(**Config().update(locals()))
@@ -487,3 +494,4 @@ def _predict(self, loader):
487494
for i, chart in zip(lens, label_preds)]
488495
if self.args.prob:
489496
batch.probs = [prob[1:i, :i].cpu() for i, prob in zip(lens, s_edge.unbind())]
497+
yield from batch.sentences

supar/utils/transform.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from torch.distributions.utils import lazy_property
2323

2424
if TYPE_CHECKING:
25-
from supar.utils import Dataset, Field
25+
from supar.utils import Field
2626

2727

2828
class Transform(object):
@@ -129,11 +129,6 @@ def src(self):
129129
def tgt(self):
130130
raise AttributeError
131131

132-
def save(self, path: str, data: Dataset) -> None:
133-
with open(path, 'w') as f:
134-
for i in data:
135-
f.write(str(i) + '\n')
136-
137132

138133
class CoNLL(Transform):
139134
r"""

0 commit comments

Comments (0)