13
13
from supar .utils .logging import get_logger , progress_bar
14
14
from supar .utils .metric import SpanMetric
15
15
from supar .utils .transform import Tree
16
+ from torch .cuda .amp import autocast
16
17
17
18
logger = get_logger (__name__ )
18
19
@@ -31,7 +32,7 @@ def __init__(self, *args, **kwargs):
31
32
self .TREE = self .transform .TREE
32
33
self .CHART = self .transform .CHART
33
34
34
- def train (self , train , dev , test , buckets = 32 , workers = 0 , batch_size = 5000 , update_steps = 1 ,
35
+ def train (self , train , dev , test , buckets = 32 , workers = 0 , batch_size = 5000 , update_steps = 1 , amp = False , cache = False ,
35
36
mbr = True ,
36
37
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
37
38
equal = {'ADVP' : 'PRT' },
@@ -47,6 +48,10 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update
47
48
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
48
49
batch_size (int):
49
50
The number of tokens in each batch. Default: 5000.
51
+ amp (bool):
52
+ Specifies whether to use automatic mixed precision. Default: ``False``.
53
+ cache (bool):
54
+ If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
50
55
update_steps (int):
51
56
Gradient accumulation steps. Default: 1.
52
57
mbr (bool):
@@ -65,7 +70,8 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update
65
70
66
71
return super ().train (** Config ().update (locals ()))
67
72
68
- def evaluate (self , data , buckets = 8 , workers = 0 , batch_size = 5000 , mbr = True ,
73
+ def evaluate (self , data , buckets = 8 , workers = 0 , batch_size = 5000 , amp = False , cache = False ,
74
+ mbr = True ,
69
75
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
70
76
equal = {'ADVP' : 'PRT' },
71
77
verbose = True ,
@@ -80,6 +86,10 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,
80
86
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
81
87
batch_size (int):
82
88
The number of tokens in each batch. Default: 5000.
89
+ amp (bool):
90
+ Specifies whether to use automatic mixed precision. Default: ``False``.
91
+ cache (bool):
92
+ If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
83
93
mbr (bool):
84
94
If ``True``, performs MBR decoding. Default: ``True``.
85
95
delete (set[str]):
@@ -99,8 +109,8 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,
99
109
100
110
return super ().evaluate (** Config ().update (locals ()))
101
111
102
- def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , cache = False , mbr = True ,
103
- verbose = True , ** kwargs ):
112
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , amp = False , cache = False , prob = False ,
113
+ mbr = True , verbose = True , ** kwargs ):
104
114
r"""
105
115
Args:
106
116
data (str or Iterable):
@@ -119,10 +129,12 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
119
129
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
120
130
batch_size (int):
121
131
The number of tokens in each batch. Default: 5000.
132
+ amp (bool):
133
+ Specifies whether to use automatic mixed precision. Default: ``False``.
134
+ cache (bool):
135
+ If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
122
136
prob (bool):
123
137
If ``True``, outputs the probabilities. Default: ``False``.
124
- cache (bool):
125
- If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
126
138
mbr (bool):
127
139
If ``True``, performs MBR decoding. Default: ``True``.
128
140
verbose (bool):
@@ -174,13 +186,16 @@ def _train(self, loader):
174
186
word_mask = words .ne (self .args .pad_index )[:, 1 :]
175
187
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
176
188
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
177
- s_span , s_label = self .model ( words , feats )
178
- loss , _ = self . model . loss ( s_span , s_label , charts , mask , self .args . mbr )
179
- loss = loss / self .args .update_steps
180
- loss . backward ()
181
- nn . utils . clip_grad_norm_ ( self .model . parameters (), self . args . clip )
189
+ with autocast ( self .args . amp ):
190
+ s_span , s_label = self .model ( words , feats )
191
+ loss , _ = self . model . loss ( s_span , s_label , charts , mask , self .args .mbr )
192
+ loss = loss / self . args . update_steps
193
+ self .scaler . scale ( loss ). backward ( )
182
194
if i % self .args .update_steps == 0 :
183
- self .optimizer .step ()
195
+ self .scaler .unscale_ (self .optimizer )
196
+ nn .utils .clip_grad_norm_ (self .model .parameters (), self .args .clip )
197
+ self .scaler .step (self .optimizer )
198
+ self .scaler .update ()
184
199
self .scheduler .step ()
185
200
self .optimizer .zero_grad ()
186
201
@@ -198,8 +213,9 @@ def _evaluate(self, loader):
198
213
word_mask = words .ne (self .args .pad_index )[:, 1 :]
199
214
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
200
215
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
201
- s_span , s_label = self .model (words , feats )
202
- loss , s_span = self .model .loss (s_span , s_label , charts , mask , self .args .mbr )
216
+ with autocast (self .args .amp ):
217
+ s_span , s_label = self .model (words , feats )
218
+ loss , s_span = self .model .loss (s_span , s_label , charts , mask , self .args .mbr )
203
219
chart_preds = self .model .decode (s_span , s_label , mask )
204
220
# since the evaluation relies on terminals,
205
221
# the tree should be first built and then factorized
@@ -222,8 +238,9 @@ def _predict(self, loader):
222
238
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
223
239
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
224
240
lens = mask [:, 0 ].sum (- 1 )
225
- s_span , s_label = self .model (words , feats )
226
- s_span = ConstituencyCRF (s_span , mask [:, 0 ].sum (- 1 )).marginals if self .args .mbr else s_span
241
+ with autocast (self .args .amp ):
242
+ s_span , s_label = self .model (words , feats )
243
+ s_span = ConstituencyCRF (s_span , mask [:, 0 ].sum (- 1 )).marginals if self .args .mbr else s_span
227
244
chart_preds = self .model .decode (s_span , s_label , mask )
228
245
batch .trees = [Tree .build (tree , [(i , j , self .CHART .vocab [label ]) for i , j , label in chart ])
229
246
for tree , chart in zip (trees , chart_preds )]
@@ -338,7 +355,7 @@ class VIConstituencyParser(CRFConstituencyParser):
338
355
NAME = 'vi-constituency'
339
356
MODEL = VIConstituencyModel
340
357
341
- def train (self , train , dev , test , buckets = 32 , workers = 0 , batch_size = 5000 , update_steps = 1 ,
358
+ def train (self , train , dev , test , buckets = 32 , workers = 0 , batch_size = 5000 , update_steps = 1 , amp = False , cache = False ,
342
359
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
343
360
equal = {'ADVP' : 'PRT' },
344
361
verbose = True ,
@@ -353,6 +370,10 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update
353
370
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
354
371
batch_size (int):
355
372
The number of tokens in each batch. Default: 5000.
373
+ amp (bool):
374
+ Specifies whether to use automatic mixed precision. Default: ``False``.
375
+ cache (bool):
376
+ If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
356
377
update_steps (int):
357
378
Gradient accumulation steps. Default: 1.
358
379
delete (set[str]):
@@ -369,7 +390,7 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update
369
390
370
391
return super ().train (** Config ().update (locals ()))
371
392
372
- def evaluate (self , data , buckets = 8 , workers = 0 , batch_size = 5000 ,
393
+ def evaluate (self , data , buckets = 8 , workers = 0 , batch_size = 5000 , amp = False , cache = False ,
373
394
delete = {'TOP' , 'S1' , '-NONE-' , ',' , ':' , '``' , "''" , '.' , '?' , '!' , '' },
374
395
equal = {'ADVP' : 'PRT' },
375
396
verbose = True ,
@@ -384,6 +405,10 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
384
405
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
385
406
batch_size (int):
386
407
The number of tokens in each batch. Default: 5000.
408
+ amp (bool):
409
+ Specifies whether to use automatic mixed precision. Default: ``False``.
410
+ cache (bool):
411
+ If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
387
412
delete (set[str]):
388
413
A set of labels that will not be taken into consideration during evaluation.
389
414
Default: {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}.
@@ -401,7 +426,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
401
426
402
427
return super ().evaluate (** Config ().update (locals ()))
403
428
404
- def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , prob = False , cache = False ,
429
+ def predict (self , data , pred = None , lang = None , buckets = 8 , workers = 0 , batch_size = 5000 , amp = False , cache = False , prob = False ,
405
430
verbose = True , ** kwargs ):
406
431
r"""
407
432
Args:
@@ -421,10 +446,12 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
421
446
The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
422
447
batch_size (int):
423
448
The number of tokens in each batch. Default: 5000.
449
+ amp (bool):
450
+ Specifies whether to use automatic mixed precision. Default: ``False``.
451
+ cache (bool):
452
+ If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
424
453
prob (bool):
425
454
If ``True``, outputs the probabilities. Default: ``False``.
426
- cache (bool):
427
- If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
428
455
mbr (bool):
429
456
If ``True``, performs MBR decoding. Default: ``True``.
430
457
verbose (bool):
@@ -476,13 +503,16 @@ def _train(self, loader):
476
503
word_mask = words .ne (self .args .pad_index )[:, 1 :]
477
504
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
478
505
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
479
- s_span , s_pair , s_label = self .model ( words , feats )
480
- loss , _ = self . model . loss ( s_span , s_pair , s_label , charts , mask )
481
- loss = loss / self .args . update_steps
482
- loss . backward ()
483
- nn . utils . clip_grad_norm_ ( self .model . parameters (), self . args . clip )
506
+ with autocast ( self .args . amp ):
507
+ s_span , s_pair , s_label = self . model ( words , feats )
508
+ loss , _ = self .model . loss ( s_span , s_pair , s_label , charts , mask )
509
+ loss = loss / self . args . update_steps
510
+ self .scaler . scale ( loss ). backward ( )
484
511
if i % self .args .update_steps == 0 :
485
- self .optimizer .step ()
512
+ self .scaler .unscale_ (self .optimizer )
513
+ nn .utils .clip_grad_norm_ (self .model .parameters (), self .args .clip )
514
+ self .scaler .step (self .optimizer )
515
+ self .scaler .update ()
486
516
self .scheduler .step ()
487
517
self .optimizer .zero_grad ()
488
518
@@ -500,8 +530,9 @@ def _evaluate(self, loader):
500
530
word_mask = words .ne (self .args .pad_index )[:, 1 :]
501
531
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
502
532
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
503
- s_span , s_pair , s_label = self .model (words , feats )
504
- loss , s_span = self .model .loss (s_span , s_pair , s_label , charts , mask )
533
+ with autocast (self .args .amp ):
534
+ s_span , s_pair , s_label = self .model (words , feats )
535
+ loss , s_span = self .model .loss (s_span , s_pair , s_label , charts , mask )
505
536
chart_preds = self .model .decode (s_span , s_label , mask )
506
537
# since the evaluation relies on terminals,
507
538
# the tree should be first built and then factorized
@@ -524,8 +555,9 @@ def _predict(self, loader):
524
555
mask = word_mask if len (words .shape ) < 3 else word_mask .any (- 1 )
525
556
mask = (mask .unsqueeze (1 ) & mask .unsqueeze (2 )).triu_ (1 )
526
557
lens = mask [:, 0 ].sum (- 1 )
527
- s_span , s_pair , s_label = self .model (words , feats )
528
- s_span = self .model .inference ((s_span , s_pair ), mask )
558
+ with autocast (self .args .amp ):
559
+ s_span , s_pair , s_label = self .model (words , feats )
560
+ s_span = self .model .inference ((s_span , s_pair ), mask )
529
561
chart_preds = self .model .decode (s_span , s_label , mask )
530
562
batch .trees = [Tree .build (tree , [(i , j , self .CHART .vocab [label ]) for i , j , label in chart ])
531
563
for tree , chart in zip (trees , chart_preds )]
0 commit comments