@@ -31,18 +31,20 @@ def __init__(self, *args, **kwargs):
31
31
self .TREE = self .transform .TREE
32
32
self .CHART = self .transform .CHART
33
33
34
- def train (self , train , dev , test , buckets = 32 , batch_size = 5000 , update_steps = 1 ,
34
+ def train (self , train , dev , test , buckets = 32 , workers = 0 , batch_size = 5000 , update_steps = 1 ,
35
35
mbr = True ,
36
36
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
37
37
equal = {'ADVP' : 'PRT' },
38
38
verbose = True ,
39
39
** kwargs ):
40
40
r"""
41
41
Args:
42
- train/dev/test (list[list] or str ):
42
+ train/dev/test (str or Iterable ):
43
43
The train/dev/test datasets. Both a filename and a list of instances are allowed.
44
44
buckets (int):
45
45
The number of buckets that sentences are assigned to. Default: 32.
46
+ workers (int):
47
+ The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
46
48
batch_size (int):
47
49
The number of tokens in each batch. Default: 5000.
48
50
update_steps (int):
@@ -63,17 +65,19 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
63
65
64
66
return super ().train (** Config ().update (locals ()))
65
67
66
- def evaluate (self , data , buckets = 8 , batch_size = 5000 , mbr = True ,
68
+ def evaluate (self , data , buckets = 8 , workers = 0 , batch_size = 5000 , mbr = True ,
67
69
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
68
70
equal = {'ADVP' : 'PRT' },
69
71
verbose = True ,
70
72
** kwargs ):
71
73
r"""
72
74
Args:
73
- data (str):
74
- The data for evaluation, both list of instances and filename are allowed.
75
+ data (str or Iterable ):
76
+ The data for evaluation. Both a filename and a list of instances are allowed.
75
77
buckets (int):
76
- The number of buckets that sentences are assigned to. Default: 32.
78
+ The number of buckets that sentences are assigned to. Default: 8.
79
+ workers (int):
80
+ The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
77
81
batch_size (int):
78
82
The number of tokens in each batch. Default: 5000.
79
83
mbr (bool):
@@ -95,19 +99,22 @@ def evaluate(self, data, buckets=8, batch_size=5000, mbr=True,
95
99
96
100
return super ().evaluate (** Config ().update (locals ()))
97
101
98
- def predict (self , data , pred = None , lang = None , buckets = 8 , batch_size = 5000 , prob = False , mbr = True , verbose = True , ** kwargs ):
102
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , mbr = True ,
103
+ verbose = True , ** kwargs ):
99
104
r"""
100
105
Args:
101
- data (list[list] or str ):
102
- The data for prediction, both a list of instances and filename are allowed.
106
+ data (str or Iterable ):
107
+ The data for prediction. Both a filename and a list of instances are allowed.
103
108
pred (str):
104
109
If specified, the predicted results will be saved to the file. Default: ``None``.
105
110
lang (str):
106
111
Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
107
112
``None`` if tokenization is not required.
108
113
Default: ``None``.
109
114
buckets (int):
110
- The number of buckets that sentences are assigned to. Default: 32.
115
+ The number of buckets that sentences are assigned to. Default: 8.
116
+ workers (int):
117
+ The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
111
118
batch_size (int):
112
119
The number of tokens in each batch. Default: 5000.
113
120
prob (bool):
@@ -159,7 +166,7 @@ def _train(self, loader):
159
166
bar = progress_bar (loader )
160
167
161
168
for i , batch in enumerate (bar , 1 ):
162
- words , * feats , trees , charts = batch
169
+ words , * feats , trees , charts = batch . compose ( self . transform )
163
170
word_mask = words .ne (self .args .pad_index )[:, 1 :]
164
171
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
165
172
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
@@ -183,7 +190,7 @@ def _evaluate(self, loader):
183
190
total_loss , metric = 0 , SpanMetric ()
184
191
185
192
for batch in loader :
186
- words , * feats , trees , charts = batch
193
+ words , * feats , trees , charts = batch . compose ( self . transform )
187
194
word_mask = words .ne (self .args .pad_index )[:, 1 :]
188
195
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
189
196
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
@@ -206,7 +213,7 @@ def _predict(self, loader):
206
213
self .model .eval ()
207
214
208
215
for batch in progress_bar (loader ):
209
- words , * feats , trees = batch
216
+ words , * feats , trees = batch . compose ( self . transform )
210
217
word_mask = words .ne (self .args .pad_index )[:, 1 :]
211
218
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
212
219
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
@@ -326,17 +333,19 @@ class VIConstituencyParser(CRFConstituencyParser):
326
333
NAME = 'vi-constituency'
327
334
MODEL = VIConstituencyModel
328
335
329
- def train (self , train , dev , test , buckets = 32 , batch_size = 5000 , update_steps = 1 ,
336
+ def train (self , train , dev , test , buckets = 32 , workers = 0 , batch_size = 5000 , update_steps = 1 ,
330
337
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
331
338
equal = {'ADVP' : 'PRT' },
332
339
verbose = True ,
333
340
** kwargs ):
334
341
r"""
335
342
Args:
336
- train/dev/test (list[list] or str ):
343
+ train/dev/test (str or Iterable ):
337
344
The train/dev/test datasets. Both a filename and a list of instances are allowed.
338
345
buckets (int):
339
346
The number of buckets that sentences are assigned to. Default: 32.
347
+ workers (int):
348
+ The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
340
349
batch_size (int):
341
350
The number of tokens in each batch. Default: 5000.
342
351
update_steps (int):
@@ -355,17 +364,19 @@ def train(self, train, dev, test, buckets=32, batch_size=5000, update_steps=1,
355
364
356
365
return super ().train (** Config ().update (locals ()))
357
366
358
- def evaluate (self , data , buckets = 8 , batch_size = 5000 ,
367
+ def evaluate (self , data , buckets = 8 , workers = 0 , batch_size = 5000 ,
359
368
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
360
369
equal = {'ADVP' : 'PRT' },
361
370
verbose = True ,
362
371
** kwargs ):
363
372
r"""
364
373
Args:
365
- data (str):
366
- The data for evaluation, both list of instances and filename are allowed.
374
+ data (str or Iterable ):
375
+ The data for evaluation. Both a filename and a list of instances are allowed.
367
376
buckets (int):
368
- The number of buckets that sentences are assigned to. Default: 32.
377
+ The number of buckets that sentences are assigned to. Default: 8.
378
+ workers (int):
379
+ The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
369
380
batch_size (int):
370
381
The number of tokens in each batch. Default: 5000.
371
382
delete (set[str]):
@@ -385,19 +396,21 @@ def evaluate(self, data, buckets=8, batch_size=5000,
385
396
386
397
return super ().evaluate (** Config ().update (locals ()))
387
398
388
- def predict (self , data , pred = None , lang = None , buckets = 8 , batch_size = 5000 , prob = False , verbose = True , ** kwargs ):
399
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , verbose = True , ** kwargs ):
389
400
r"""
390
401
Args:
391
- data (list[list] or str ):
392
- The data for prediction, both a list of instances and filename are allowed.
402
+ data (str or Iterable ):
403
+ The data for prediction. Both a filename and a list of instances are allowed.
393
404
pred (str):
394
405
If specified, the predicted results will be saved to the file. Default: ``None``.
395
406
lang (str):
396
407
Language code (e.g., ``en``) or language name (e.g., ``English``) for the text to tokenize.
397
408
``None`` if tokenization is not required.
398
409
Default: ``None``.
399
410
buckets (int):
400
- The number of buckets that sentences are assigned to. Default: 32.
411
+ The number of buckets that sentences are assigned to. Default: 8.
412
+ workers (int):
413
+ The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
401
414
batch_size (int):
402
415
The number of tokens in each batch. Default: 5000.
403
416
prob (bool):
@@ -449,7 +462,7 @@ def _train(self, loader):
449
462
bar = progress_bar (loader )
450
463
451
464
for i , batch in enumerate (bar , 1 ):
452
- words , * feats , trees , charts = batch
465
+ words , * feats , trees , charts = batch . compose ( self . transform )
453
466
word_mask = words .ne (self .args .pad_index )[:, 1 :]
454
467
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
455
468
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
@@ -473,7 +486,7 @@ def _evaluate(self, loader):
473
486
total_loss , metric = 0 , SpanMetric ()
474
487
475
488
for batch in loader :
476
- words , * feats , trees , charts = batch
489
+ words , * feats , trees , charts = batch . compose ( self . transform )
477
490
word_mask = words .ne (self .args .pad_index )[:, 1 :]
478
491
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
479
492
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
@@ -496,7 +509,7 @@ def _predict(self, loader):
496
509
self .model .eval ()
497
510
498
511
for batch in progress_bar (loader ):
499
- words , * feats , trees = batch
512
+ words , * feats , trees = batch . compose ( self . transform )
500
513
word_mask = words .ne (self .args .pad_index )[:, 1 :]
501
514
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
502
515
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
0 commit comments