Skip to content

Commit 8decdb1

Browse files
committed
Support parsing plain texts
1 parent 5903361 commit 8decdb1

File tree

5 files changed

+60 -25 lines changed

5 files changed

+60 -25 lines changed

supar/parsers/const.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -104,7 +104,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
104104
r"""
105105
Args:
106106
data (str or Iterable):
107-
The data for prediction. Both a filename and a list of instances are allowed.
107+
The data for prediction.
108+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
109+
- a list of instances.
108110
pred (str):
109111
If specified, the predicted results will be saved to the file. Default: ``None``.
110112
lang (str):
@@ -400,7 +402,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
400402
r"""
401403
Args:
402404
data (str or Iterable):
403-
The data for prediction. Both a filename and a list of instances are allowed.
405+
The data for prediction.
406+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
407+
- a list of instances.
404408
pred (str):
405409
If specified, the predicted results will be saved to the file. Default: ``None``.
406410
lang (str):

supar/parsers/dep.py

Lines changed: 12 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
9999
r"""
100100
Args:
101101
data (str or Iterable):
102-
The data for prediction. Both a filename and a list of instances are allowed.
102+
The data for prediction.
103+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
104+
- a list of instances.
103105
pred (str):
104106
If specified, the predicted results will be saved to the file. Default: ``None``.
105107
lang (str):
@@ -411,7 +413,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
411413
r"""
412414
Args:
413415
data (str or Iterable):
414-
The data for prediction. Both a filename and a list of instances are allowed.
416+
The data for prediction.
417+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
418+
- a list of instances.
415419
pred (str):
416420
If specified, the predicted results will be saved to the file. Default: ``None``.
417421
lang (str):
@@ -632,7 +636,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
632636
r"""
633637
Args:
634638
data (str or Iterable):
635-
The data for prediction. Both a filename and a list of instances are allowed.
639+
The data for prediction.
640+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
641+
- a list of instances.
636642
pred (str):
637643
If specified, the predicted results will be saved to the file. Default: ``None``.
638644
lang (str):
@@ -944,7 +950,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
944950
r"""
945951
Args:
946952
data (str or Iterable):
947-
The data for prediction. Both a filename and a list of instances are allowed.
953+
The data for prediction.
954+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
955+
- a list of instances.
948956
pred (str):
949957
If specified, the predicted results will be saved to the file. Default: ``None``.
950958
lang (str):

supar/parsers/parser.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -115,7 +115,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, **kwargs):
115115

116116
self.transform.train()
117117
logger.info("Loading the data")
118-
dataset = Dataset(self.transform, data)
118+
dataset = Dataset(self.transform, **args)
119119
dataset.build(batch_size, buckets, False, False, workers)
120120
logger.info(f"\n{dataset}")
121121

@@ -137,7 +137,7 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
137137
self.transform.append(Field('probs'))
138138

139139
logger.info("Loading the data")
140-
dataset = Dataset(self.transform, data, lang=lang)
140+
dataset = Dataset(self.transform, **args)
141141
dataset.build(batch_size, buckets, False, False, workers)
142142
logger.info(f"\n{dataset}")
143143

supar/parsers/sdp.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -79,7 +79,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
7979
r"""
8080
Args:
8181
data (str or Iterable):
82-
The data for prediction. Both a filename and a list of instances are allowed.
82+
The data for prediction.
83+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
84+
- a list of instances.
8385
pred (str):
8486
If specified, the predicted results will be saved to the file. Default: ``None``.
8587
lang (str):
@@ -362,7 +364,9 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
362364
r"""
363365
Args:
364366
data (str or Iterable):
365-
The data for prediction. Both a filename and a list of instances are allowed.
367+
The data for prediction.
368+
- a filename. If ends with `.txt`, the parser will seek to make predictions line by line from plain texts.
369+
- a list of instances.
366370
pred (str):
367371
If specified, the predicted results will be saved to the file. Default: ``None``.
368372
lang (str):

supar/utils/transform.py

Lines changed: 34 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -3,10 +3,12 @@
33
from __future__ import annotations
44

55
import os
6+
import re
67
import shutil
78
import tempfile
89
from collections.abc import Iterable
910
from contextlib import contextmanager
11+
from io import StringIO
1012
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
1113

1214
import nltk
@@ -393,23 +395,31 @@ def load(
393395
A list of :class:`CoNLLSentence` instances.
394396
"""
395397

398+
isconll = False
399+
if lang is not None:
400+
tokenizer = Tokenizer(lang)
396401
if isinstance(data, str) and os.path.exists(data):
397-
with open(data, 'r') as f:
398-
lines = [line.strip() for line in f]
402+
f = open(data)
403+
if data.endswith('.txt'):
404+
lines = (i
405+
for s in f
406+
if len(s) > 1
407+
for i in StringIO(self.toconll(s.split() if lang is None else tokenizer(s)) + '\n'))
408+
else:
409+
lines, isconll = f, True
399410
else:
400411
if lang is not None:
401-
tokenizer = Tokenizer(lang)
402-
data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
412+
data = [tokenizer(s) for s in ([data] if isinstance(data, str) else data)]
403413
else:
404414
data = [data] if isinstance(data[0], str) else data
405-
lines = '\n'.join([self.toconll(i) for i in data]).split('\n')
415+
lines = (i for s in data for i in StringIO(self.toconll(s) + '\n'))
406416

407417
index, sentence = 0, []
408-
for line in lines:
418+
for line in progress_bar(lines):
409419
line = line.strip()
410420
if len(line) == 0:
411421
sentence = CoNLLSentence(self, sentence, index)
412-
if proj and not self.isprojective(list(map(int, sentence.arcs))):
422+
if isconll and proj and not self.isprojective(list(map(int, sentence.arcs))):
413423
logger.warning(f"Sentence {index} is not projective. Discarding it!")
414424
elif max_len is not None and len(sentence) >= max_len:
415425
logger.warning(f"Sentence {index} has {len(sentence)} tokens, exceeding {max_len}. Discarding it!")
@@ -492,10 +502,11 @@ def totree(
492502

493503
if isinstance(tokens[0], str):
494504
tokens = [(token, '_') for token in tokens]
495-
mapped = []
505+
mapped, pattern = [], re.compile(f'[{"".join(special_tokens)}]')
496506
for i, (word, pos) in enumerate(tokens):
497-
if word in special_tokens:
498-
tokens[i] = (special_tokens[word], pos)
507+
match = re.search(pattern, word)
508+
if match:
509+
tokens[i] = (pattern.sub(lambda m: special_tokens[m[0]], word), pos)
499510
mapped.append((i, word))
500511
tree = nltk.Tree.fromstring(f"({root} {' '.join([f'( ({pos} {word}))' for word, pos in tokens])})")
501512
for i, word in mapped:
@@ -690,19 +701,27 @@ def load(
690701
A list of :class:`TreeSentence` instances.
691702
"""
692703

704+
if lang is not None:
705+
tokenizer = Tokenizer(lang)
693706
if isinstance(data, str) and os.path.exists(data):
694-
data = open(data, 'r')
707+
if data.endswith('.txt'):
708+
data = (s.split() if lang is None else tokenizer(s) for s in open(data) if len(s) > 1)
709+
else:
710+
data = open(data)
695711
else:
696712
if lang is not None:
697-
tokenizer = Tokenizer(lang)
698713
data = [tokenizer(i) for i in ([data] if isinstance(data, str) else data)]
699714
else:
700715
data = [data] if isinstance(data[0], str) else data
701716

702717
index = 0
703-
for s in data:
704-
tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
705-
sentence = TreeSentence(self, tree, index)
718+
for s in progress_bar(data):
719+
try:
720+
tree = nltk.Tree.fromstring(s) if isinstance(s, str) else self.totree(s, self.root)
721+
sentence = TreeSentence(self, tree, index)
722+
except ValueError:
723+
logger.warning(f"Error found while converting Sentence {index} to a tree:\n{s}\nDiscarding it!")
724+
continue
706725
if max_len is not None and len(sentence) >= max_len:
707726
logger.warning(f"Sentence {index} has {len(sentence)} tokens, exceeding {max_len}. Discarding it!")
708727
else:

0 commit comments

Comments
 (0)