
Commit b500ef9

Support auto mixed precision
1 parent 35e3550 commit b500ef9

5 files changed: +258 −127 lines

supar/cmds/cmd.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ def init(parser):
     parser.add_argument('--workers', '-w', default=0, type=int, help='num of processes used for data loading')
     parser.add_argument('--cache', action='store_true', help='cache the data for fast loading')
     parser.add_argument('--binarize', action='store_true', help='binarize the data first')
+    parser.add_argument('--amp', action='store_true', help='use automatic mixed precision for parsing')
     args, unknown = parser.parse_known_args()
     args, unknown = parser.parse_known_args(unknown, args)
     args = Config.load(**vars(args), unknown=unknown)
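
Note: because the flag is registered in the shared init parser, every supar subcommand inherits --amp, and it surfaces as the amp keyword on the parser APIs changed below. A minimal usage sketch (the pretrained model name and data path are placeholders, not part of this commit):

from supar import Parser

# 'crf-con-en' and the file path are illustrative only; the new `amp` keyword
# mirrors the `--amp` CLI flag added above.
con = Parser.load('crf-con-en')
dataset = con.predict('ptb/test.pid', amp=True, verbose=False)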

supar/parsers/const.py

Lines changed: 63 additions & 31 deletions
@@ -13,6 +13,7 @@
 from supar.utils.logging import get_logger, progress_bar
 from supar.utils.metric import SpanMetric
 from supar.utils.transform import Tree
+from torch.cuda.amp import autocast

 logger = get_logger(__name__)

@@ -31,7 +32,7 @@ def __init__(self, *args, **kwargs):
         self.TREE = self.transform.TREE
         self.CHART = self.transform.CHART

-    def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update_steps=1,
+    def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update_steps=1, amp=False, cache=False,
               mbr=True,
               delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
               equal={'ADVP': 'PRT'},
@@ -47,6 +48,10 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update
                 The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
             batch_size (int):
                 The number of tokens in each batch. Default: 5000.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
             update_steps (int):
                 Gradient accumulation steps. Default: 1.
             mbr (bool):
@@ -65,7 +70,8 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update

         return super().train(**Config().update(locals()))

-    def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,
+    def evaluate(self, data, buckets=8, workers=0, batch_size=5000, amp=False, cache=False,
+                 mbr=True,
                  delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
                  equal={'ADVP': 'PRT'},
                  verbose=True,
@@ -80,6 +86,10 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,
                 The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
             batch_size (int):
                 The number of tokens in each batch. Default: 5000.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
             mbr (bool):
                 If ``True``, performs MBR decoding. Default: ``True``.
             delete (set[str]):
@@ -99,8 +109,8 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000, mbr=True,

         return super().evaluate(**Config().update(locals()))

-    def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False, mbr=True,
-                verbose=True, **kwargs):
+    def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, amp=False, cache=False, prob=False,
+                mbr=True, verbose=True, **kwargs):
         r"""
         Args:
             data (str or Iterable):
@@ -119,10 +129,12 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
                 The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
             batch_size (int):
                 The number of tokens in each batch. Default: 5000.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
             prob (bool):
                 If ``True``, outputs the probabilities. Default: ``False``.
-            cache (bool):
-                If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
             mbr (bool):
                 If ``True``, performs MBR decoding. Default: ``True``.
             verbose (bool):
@@ -174,13 +186,16 @@ def _train(self, loader):
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
-            s_span, s_label = self.model(words, feats)
-            loss, _ = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
-            loss = loss / self.args.update_steps
-            loss.backward()
-            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
+            with autocast(self.args.amp):
+                s_span, s_label = self.model(words, feats)
+                loss, _ = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
+                loss = loss / self.args.update_steps
+            self.scaler.scale(loss).backward()
             if i % self.args.update_steps == 0:
-                self.optimizer.step()
+                self.scaler.unscale_(self.optimizer)
+                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
                 self.scheduler.step()
                 self.optimizer.zero_grad()
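
Note: the training loop now follows the standard PyTorch AMP recipe: the forward pass and loss run under autocast, the scaled loss is backpropagated, and gradients are unscaled before clipping so that clipping operates on true gradient magnitudes. The self.scaler used here is presumably a torch.cuda.amp.GradScaler created alongside the optimizer in the shared training loop (one of this commit's other files, not shown above). A minimal, self-contained sketch of the same pattern with a toy model (assumes a CUDA device):

import torch
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

model = nn.Linear(10, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=True)               # behaves as a no-op when enabled=False
update_steps, clip = 2, 5.0

for i, x in enumerate([torch.randn(4, 10, device='cuda') for _ in range(4)], 1):
    with autocast(True):                        # forward/loss in mixed precision
        loss = model(x).pow(2).mean() / update_steps
    scaler.scale(loss).backward()               # scale to avoid fp16 gradient underflow
    if i % update_steps == 0:
        scaler.unscale_(optimizer)              # unscale before clipping real gradients
        nn.utils.clip_grad_norm_(model.parameters(), clip)
        scaler.step(optimizer)                  # skips the update if any grad overflowed
        scaler.update()
        optimizer.zero_grad()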

@@ -198,8 +213,9 @@ def _evaluate(self, loader):
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
-            s_span, s_label = self.model(words, feats)
-            loss, s_span = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
+            with autocast(self.args.amp):
+                s_span, s_label = self.model(words, feats)
+                loss, s_span = self.model.loss(s_span, s_label, charts, mask, self.args.mbr)
             chart_preds = self.model.decode(s_span, s_label, mask)
             # since the evaluation relies on terminals,
             # the tree should be first built and then factorized
@@ -222,8 +238,9 @@ def _predict(self, loader):
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
             lens = mask[:, 0].sum(-1)
-            s_span, s_label = self.model(words, feats)
-            s_span = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).marginals if self.args.mbr else s_span
+            with autocast(self.args.amp):
+                s_span, s_label = self.model(words, feats)
+                s_span = ConstituencyCRF(s_span, mask[:, 0].sum(-1)).marginals if self.args.mbr else s_span
             chart_preds = self.model.decode(s_span, s_label, mask)
             batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
                            for tree, chart in zip(trees, chart_preds)]
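
Note: at evaluation and prediction time nothing is backpropagated, so no GradScaler is involved; these hunks only wrap the forward pass (and, under MBR, the CRF marginals) in autocast, while decode stays outside the block. A minimal sketch of that shape, with a toy model standing in for the parser:

import torch
from torch.cuda.amp import autocast

model = torch.nn.Linear(10, 5).cuda().eval()
x = torch.randn(8, 10, device='cuda')

with torch.no_grad(), autocast(True):
    scores = model(x)                           # forward pass in mixed precision
preds = scores.argmax(-1)                       # decoding on the resulting scores happens
                                                # outside the autocast block, as in the diff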
@@ -338,7 +355,7 @@ class VIConstituencyParser(CRFConstituencyParser):
     NAME = 'vi-constituency'
     MODEL = VIConstituencyModel

-    def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update_steps=1,
+    def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update_steps=1, amp=False, cache=False,
               delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
               equal={'ADVP': 'PRT'},
               verbose=True,
@@ -353,6 +370,10 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update
                 The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
             batch_size (int):
                 The number of tokens in each batch. Default: 5000.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
             update_steps (int):
                 Gradient accumulation steps. Default: 1.
             delete (set[str]):
@@ -369,7 +390,7 @@ def train(self, train, dev, test, buckets=32, workers=0, batch_size=5000, update

         return super().train(**Config().update(locals()))

-    def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
+    def evaluate(self, data, buckets=8, workers=0, batch_size=5000, amp=False, cache=False,
                  delete={'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''},
                  equal={'ADVP': 'PRT'},
                  verbose=True,
@@ -384,6 +405,10 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,
                 The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
             batch_size (int):
                 The number of tokens in each batch. Default: 5000.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
             delete (set[str]):
                 A set of labels that will not be taken into consideration during evaluation.
                 Default: {'TOP', 'S1', '-NONE-', ',', ':', '``', "''", '.', '?', '!', ''}.
@@ -401,7 +426,7 @@ def evaluate(self, data, buckets=8, workers=0, batch_size=5000,

         return super().evaluate(**Config().update(locals()))

-    def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, prob=False, cache=False,
+    def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5000, amp=False, cache=False, prob=False,
                 verbose=True, **kwargs):
         r"""
         Args:
@@ -421,10 +446,12 @@ def predict(self, data, pred=None, lang=None, buckets=8, workers=0, batch_size=5
                 The number of subprocesses used for data loading. 0 means only the main process. Default: 0.
             batch_size (int):
                 The number of tokens in each batch. Default: 5000.
+            amp (bool):
+                Specifies whether to use automatic mixed precision. Default: ``False``.
+            cache (bool):
+                If ``True``, caches the data first, suggested for huge files (e.g., > 1M sentences). Default: ``False``.
             prob (bool):
                 If ``True``, outputs the probabilities. Default: ``False``.
-            cache (bool):
-                If ``True``, caches the data first, suggested if parsing huge files (e.g., > 1M sentences). Default: ``False``.
             mbr (bool):
                 If ``True``, performs MBR decoding. Default: ``True``.
             verbose (bool):
@@ -476,13 +503,16 @@ def _train(self, loader):
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
-            s_span, s_pair, s_label = self.model(words, feats)
-            loss, _ = self.model.loss(s_span, s_pair, s_label, charts, mask)
-            loss = loss / self.args.update_steps
-            loss.backward()
-            nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
+            with autocast(self.args.amp):
+                s_span, s_pair, s_label = self.model(words, feats)
+                loss, _ = self.model.loss(s_span, s_pair, s_label, charts, mask)
+                loss = loss / self.args.update_steps
+            self.scaler.scale(loss).backward()
             if i % self.args.update_steps == 0:
-                self.optimizer.step()
+                self.scaler.unscale_(self.optimizer)
+                nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip)
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
                 self.scheduler.step()
                 self.optimizer.zero_grad()

@@ -500,8 +530,9 @@ def _evaluate(self, loader):
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
-            s_span, s_pair, s_label = self.model(words, feats)
-            loss, s_span = self.model.loss(s_span, s_pair, s_label, charts, mask)
+            with autocast(self.args.amp):
+                s_span, s_pair, s_label = self.model(words, feats)
+                loss, s_span = self.model.loss(s_span, s_pair, s_label, charts, mask)
             chart_preds = self.model.decode(s_span, s_label, mask)
             # since the evaluation relies on terminals,
             # the tree should be first built and then factorized
@@ -524,8 +555,9 @@ def _predict(self, loader):
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
             lens = mask[:, 0].sum(-1)
-            s_span, s_pair, s_label = self.model(words, feats)
-            s_span = self.model.inference((s_span, s_pair), mask)
+            with autocast(self.args.amp):
+                s_span, s_pair, s_label = self.model(words, feats)
+                s_span = self.model.inference((s_span, s_pair), mask)
             chart_preds = self.model.decode(s_span, s_label, mask)
             batch.trees = [Tree.build(tree, [(i, j, self.CHART.vocab[label]) for i, j, label in chart])
                            for tree, chart in zip(trees, chart_preds)]
