Skip to content

Commit 7968a87

Browse files
committed
Adapt more args of DataLoader
1 parent 7452af0 commit 7968a87

File tree

7 files changed

+211
-145
lines changed

7 files changed

+211
-145
lines changed

supar/cmds/cmd.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@ def init(parser):
1515
parser.add_argument('--conf', '-c', default='', help='path to config file')
1616
parser.add_argument('--device', '-d', default='-1', help='ID of GPU to use')
1717
parser.add_argument('--seed', '-s', default=1, type=int, help='seed for generating random numbers')
18-
parser.add_argument('--threads', '-t', default=16, type=int, help='max num of threads')
18+
parser.add_argument('--threads', '-t', default=16, type=int, help='num of threads')
19+
parser.add_argument('--workers', '-w', default=0, type=int, help='num of processes used for data loading')
1920
args, unknown = parser.parse_known_args()
2021
args, unknown = parser.parse_known_args(unknown, args)
2122
args = Config.load(**vars(args), unknown=unknown)

supar/parsers/const.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@ def __init__(self, *args, **kwargs):
3131
self.TREE = self.transform.TREE
3232
self.CHART = self.transform.CHART
3333

34-
def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
34+
def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update_steps=1,
3535
mbr=True,
3636
delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
3737
equal={'ADVP': 'PRT'},
3838
verbose=True,
3939
**kwargs):
4040
r"""
4141
Args:
42-
train/dev/test (list[list] or str):
42+
train/dev/test (str or Iterable):
4343
Filenames of the train/dev/test datasets.
4444
buckets (int):
4545
The number of buckets that sentences are assigned to. Default: 32.
46+
workers (int):
47+
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
4648
batch_size (int):
4749
The number of tokens in each batch. Default: 5000.
4850
update_steps (int):
@@ -63,17 +65,19 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
6365

6466
return super().train(**Config().update(locals()))
6567

66-
def evaluate(self, data, buckets=8, batch_size=5000, mbr=True,
68+
def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,
6769
delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
6870
equal={'ADVP': 'PRT'},
6971
verbose=True,
7072
**kwargs):
7173
r"""
7274
Args:
73-
data (str):
74-
The data for evaluation, both list of instances and filename are allowed.
75+
data (str or Iterable):
76+
The data for evaluation. Both a filename and a list of instances are allowed.
7577
buckets (int):
76-
The number of buckets that sentences are assigned to. Default: 32.
78+
The number of buckets that sentences are assigned to. Default: 8.
79+
workers (int):
80+
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
7781
batch_size (int):
7882
The number of tokens in each batch. Default: 5000.
7983
mbr (bool):
@@ -95,19 +99,22 @@ def evaluate(self, data, buckets=8, batch_size=5000, mbr=True,
9599

96100
return super().evaluate(**Config().update(locals()))
97101

98-
def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=False, mbr=True, verbose=True, **kwargs):
102+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, mbr=True,
103+
verbose=True, **kwargs):
99104
r"""
100105
Args:
101-
data (list[list] or str):
102-
The data for prediction, both a list of instances and filename are allowed.
106+
data (str or Iterable):
107+
The data for prediction. Both a filename and a list of instances are allowed.
103108
pred (str):
104109
If specified, the predicted results will be saved to the file. Default: ``None``.
105110
lang (str):
106111
Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
107112
``None`` if tokenization is not required.
108113
Default: ``None``.
109114
buckets (int):
110-
The number of buckets that sentences are assigned to. Default: 32.
115+
The number of buckets that sentences are assigned to. Default: 8.
116+
workers (int):
117+
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
111118
batch_size (int):
112119
The number of tokens in each batch. Default: 5000.
113120
prob (bool):
@@ -159,7 +166,7 @@ def _train(self, loader):
159166
bar = progress_bar(loader)
160167

161168
for i, batch in enumerate(bar, 1):
162-
words, *feats, trees, charts = batch
169+
words, *feats, trees, charts = batch.compose(self.transform)
163170
word_mask = words.ne(self.args.pad_index)[:, 1:]
164171
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
165172
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
@@ -183,7 +190,7 @@ def _evaluate(self, loader):
183190
total_loss, metric = 0, SpanMetric()
184191

185192
for batch in loader:
186-
words, *feats, trees, charts = batch
193+
words, *feats, trees, charts = batch.compose(self.transform)
187194
word_mask = words.ne(self.args.pad_index)[:, 1:]
188195
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
189196
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
@@ -206,7 +213,7 @@ def _predict(self, loader):
206213
self.model.eval()
207214

208215
for batch in progress_bar(loader):
209-
words, *feats, trees = batch
216+
words, *feats, trees = batch.compose(self.transform)
210217
word_mask = words.ne(self.args.pad_index)[:, 1:]
211218
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
212219
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
@@ -326,17 +333,19 @@ class VIConstituencyParser(CRFConstituencyParser):
326333
NAME = 'vi-constituency'
327334
MODEL = VIConstituencyModel
328335

329-
def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
336+
def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update_steps=1,
330337
delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
331338
equal={'ADVP': 'PRT'},
332339
verbose=True,
333340
**kwargs):
334341
r"""
335342
Args:
336-
train/dev/test (list[list] or str):
343+
train/dev/test (str or Iterable):
337344
Filenames of the train/dev/test datasets.
338345
buckets (int):
339346
The number of buckets that sentences are assigned to. Default: 32.
347+
workers (int):
348+
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
340349
batch_size (int):
341350
The number of tokens in each batch. Default: 5000.
342351
update_steps (int):
@@ -355,17 +364,19 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
355364

356365
return super().train(**Config().update(locals()))
357366

358-
def evaluate(self, data, buckets=8, batch_size=5000,
367+
def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
359368
delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
360369
equal={'ADVP': 'PRT'},
361370
verbose=True,
362371
**kwargs):
363372
r"""
364373
Args:
365-
data (str):
366-
The data for evaluation, both list of instances and filename are allowed.
374+
data (str or Iterable):
375+
The data for evaluation. Both a filename and a list of instances are allowed.
367376
buckets (int):
368-
The number of buckets that sentences are assigned to. Default: 32.
377+
The number of buckets that sentences are assigned to. Default: 8.
378+
workers (int):
379+
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
369380
batch_size (int):
370381
The number of tokens in each batch. Default: 5000.
371382
delete (set[str]):
@@ -385,19 +396,21 @@ def evaluate(self, data, buckets=8, batch_size=5000,
385396

386397
return super().evaluate(**Config().update(locals()))
387398

388-
def predict(self, data, pred=None, lang=None, buckets=8, batch_size=5000, prob=False, verbose=True, **kwargs):
399+
def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, verbose=True, **kwargs):
389400
r"""
390401
Args:
391-
data (list[list] or str):
392-
The data for prediction, both a list of instances and filename are allowed.
402+
data (str or Iterable):
403+
The data for prediction. Both a filename and a list of instances are allowed.
393404
pred (str):
394405
If specified, the predicted results will be saved to the file. Default: ``None``.
395406
lang (str):
396407
Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
397408
``None`` if tokenization is not required.
398409
Default: ``None``.
399410
buckets (int):
400-
The number of buckets that sentences are assigned to. Default: 32.
411+
The number of buckets that sentences are assigned to. Default: 8.
412+
workers (int):
413+
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
401414
batch_size (int):
402415
The number of tokens in each batch. Default: 5000.
403416
prob (bool):
@@ -449,7 +462,7 @@ def _train(self, loader):
449462
bar = progress_bar(loader)
450463

451464
for i, batch in enumerate(bar, 1):
452-
words, *feats, trees, charts = batch
465+
words, *feats, trees, charts = batch.compose(self.transform)
453466
word_mask = words.ne(self.args.pad_index)[:, 1:]
454467
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
455468
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
@@ -473,7 +486,7 @@ def _evaluate(self, loader):
473486
total_loss, metric = 0, SpanMetric()
474487

475488
for batch in loader:
476-
words, *feats, trees, charts = batch
489+
words, *feats, trees, charts = batch.compose(self.transform)
477490
word_mask = words.ne(self.args.pad_index)[:, 1:]
478491
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
479492
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
@@ -496,7 +509,7 @@ def _predict(self, loader):
496509
self.model.eval()
497510

498511
for batch in progress_bar(loader):
499-
words, *feats, trees = batch
512+
words, *feats, trees = batch.compose(self.transform)
500513
word_mask = words.ne(self.args.pad_index)[:, 1:]
501514
mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
502515
mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)

0 commit comments

Comments
 (0)