
Commit d6bf471

Batch object

1 parent b3c4893 commit d6bf471

6 files changed: +204 −177 lines changed

supar/parsers/const.py

Lines changed: 12 additions & 6 deletions

@@ -158,7 +158,8 @@ def _train(self, loader):

         bar = progress_bar(loader)

-        for i, (words, *feats, trees, charts) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, *feats, trees, charts = batch
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)

@@ -181,7 +182,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, SpanMetric()

-        for words, *feats, trees, charts in loader:
+        for batch in loader:
+            words, *feats, trees, charts = batch
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)

@@ -204,7 +206,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'trees': [], 'probs': [] if self.args.prob else None}
-        for words, *feats, trees in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, *feats, trees = batch
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)

@@ -449,7 +452,8 @@ def _train(self, loader):

         bar = progress_bar(loader)

-        for i, (words, *feats, trees, charts) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, *feats, trees, charts = batch
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)

@@ -472,7 +476,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, SpanMetric()

-        for words, *feats, trees, charts in loader:
+        for batch in loader:
+            words, *feats, trees, charts = batch
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)

@@ -495,7 +500,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'trees': [], 'probs': [] if self.args.prob else None}
-        for words, *feats, trees in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, *feats, trees = batch
             word_mask = words.ne(self.args.pad_index)[:, 1:]
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = (mask.unsqueeze(1) & mask.unsqueeze(2)).triu_(1)
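
The rewrite is the same in every hunk: the loop header no longer destructures a tuple, and the unpacking moves into the loop body. This works because starred assignment only requires the yielded object to be iterable, so the new Batch object (see supar/utils/data.py below) can stand in for the old namedtuple without touching the rest of the loop. A minimal sketch, assuming only that Batch yields one composed tensor per field:

    # FakeBatch is a stand-in for supar.utils.transform.Batch; any iterable
    # supports starred unpacking, so the rewritten loop bodies keep working.
    class FakeBatch:
        def __init__(self, tensors):
            self.tensors = tensors

        def __iter__(self):
            # yield one composed tensor per field, in field order
            yield from self.tensors

    batch = FakeBatch(['words', 'chars', 'trees', 'charts'])
    words, *feats, trees, charts = batch   # mirrors the rewritten loop body
    assert feats == ['chars']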

supar/parsers/dep.py

Lines changed: 24 additions & 12 deletions

@@ -155,7 +155,8 @@ def _train(self, loader):

         bar, metric = progress_bar(loader), AttachmentMetric()

-        for i, (words, texts, *feats, arcs, rels) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, texts, *feats, arcs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -186,7 +187,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, AttachmentMetric()

-        for words, texts, *feats, arcs, rels in loader:
+        for batch in loader:
+            words, texts, *feats, arcs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -210,7 +212,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
-        for words, texts, *feats in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, texts, *feats = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -465,7 +468,8 @@ def _train(self, loader):

         bar, metric = progress_bar(loader), AttachmentMetric()

-        for i, (words, texts, *feats, arcs, rels) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, texts, *feats, arcs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -496,7 +500,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, AttachmentMetric()

-        for words, texts, *feats, arcs, rels in loader:
+        for batch in loader:
+            words, texts, *feats, arcs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -520,7 +525,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
-        for words, texts, *feats in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, texts, *feats = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -681,7 +687,8 @@ def _train(self, loader):

         bar, metric = progress_bar(loader), AttachmentMetric()

-        for i, (words, texts, *feats, arcs, sibs, rels) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, texts, *feats, arcs, sibs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -712,7 +719,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, AttachmentMetric()

-        for words, texts, *feats, arcs, sibs, rels in loader:
+        for batch in loader:
+            words, texts, *feats, arcs, sibs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -736,7 +744,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
-        for words, texts, *feats in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, texts, *feats = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -987,7 +996,8 @@ def _train(self, loader):

         bar, metric = progress_bar(loader), AttachmentMetric()

-        for i, (words, texts, *feats, arcs, rels) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, texts, *feats, arcs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -1018,7 +1028,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, AttachmentMetric()

-        for words, texts, *feats, arcs, rels in loader:
+        for batch in loader:
+            words, texts, *feats, arcs, rels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence

@@ -1042,7 +1053,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'arcs': [], 'rels': [], 'probs': [] if self.args.prob else None}
-        for words, texts, *feats in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, texts, *feats = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             # ignore the first token of each sentence
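
The unchanged context lines show the mask logic every dependency loop relies on: word-level inputs give a (batch, seq_len) id matrix, while char/subword inputs add a fix_len dimension that .any(-1) collapses, and the trailing comment refers to masking out the BOS/root position before scoring. A toy illustration; note that the actual first-token masking statement sits just past these hunks' context, so the last assignment below is inferred rather than shown:

    import torch

    pad_index = 0
    words = torch.tensor([[1, 5, 7, 0],    # two padded word-level sentences
                          [1, 3, 0, 0]])
    word_mask = words.ne(pad_index)
    # with 3-d char/subword ids, any(-1) would collapse the fix_len dimension
    mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
    mask[:, 0] = False                     # ignore the first token (BOS/root)
    print(mask)
    # tensor([[False,  True,  True, False],
    #         [False,  True, False, False]])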

supar/parsers/sdp.py

Lines changed: 12 additions & 6 deletions

@@ -132,7 +132,8 @@ def _train(self, loader):

         bar, metric = progress_bar(loader), ChartMetric()

-        for i, (words, *feats, labels) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, *feats, labels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)

@@ -158,7 +159,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, ChartMetric()

-        for words, *feats, labels in loader:
+        for batch in loader:
+            words, *feats, labels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)

@@ -178,7 +180,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'labels': [], 'probs': [] if self.args.prob else None}
-        for words, *feats in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, *feats = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)

@@ -409,7 +412,8 @@ def _train(self, loader):

         bar, metric = progress_bar(loader), ChartMetric()

-        for i, (words, *feats, labels) in enumerate(bar, 1):
+        for i, batch in enumerate(bar, 1):
+            words, *feats, labels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)

@@ -435,7 +439,8 @@ def _evaluate(self, loader):

         total_loss, metric = 0, ChartMetric()

-        for words, *feats, labels in loader:
+        for batch in loader:
+            words, *feats, labels = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)

@@ -455,7 +460,8 @@ def _predict(self, loader):
         self.model.eval()

         preds = {'labels': [], 'probs': [] if self.args.prob else None}
-        for words, *feats in progress_bar(loader):
+        for batch in progress_bar(loader):
+            words, *feats = batch
             word_mask = words.ne(self.args.pad_index)
             mask = word_mask if len(words.shape) < 3 else word_mask.any(-1)
             mask = mask.unsqueeze(1) & mask.unsqueeze(2)
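
The context here differs from const.py in one detail: the token mask is broadcast into a full (batch, seq, seq) pair mask without the .triu_(1) call, since semantic dependency parsing scores every head-dependent pair rather than upper-triangular spans. On a toy mask:

    import torch

    mask = torch.tensor([[True, True, False]])          # one sentence, one pad slot
    pair_mask = mask.unsqueeze(1) & mask.unsqueeze(2)   # broadcasts to (1, 3, 3)
    print(pair_mask[0])
    # tensor([[ True,  True, False],
    #         [ True,  True, False],
    #         [False, False, False]])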

supar/utils/data.py

Lines changed: 5 additions & 23 deletions

@@ -1,10 +1,10 @@
 # -*- coding: utf-8 -*-

-from collections import namedtuple
-
 import torch
 import torch.distributed as dist
 from supar.utils.alg import kmeans
+from supar.utils.transform import Batch
+from torch.utils.data import DataLoader


 class Dataset(torch.utils.data.Dataset):

@@ -74,35 +74,17 @@ def __getstate__(self):
     def __setstate__(self, state):
         self.__dict__.update(state)

-    def collate_fn(self, batch):
-        if not hasattr(self, 'fields'):
-            raise RuntimeError("The fields are not numericalized yet. Please build the dataset first.")
-        return {f: [s.transformed[f.name] for s in batch] for f in self.fields}
-
     def build(self, batch_size, n_buckets=1, shuffle=False, distributed=False):
         # numericalize all fields
-        self.fields = self.transform(self.sentences)
+        fields = self.transform(self.sentences)
         # NOTE: the final bucket count is roughly equal to n_buckets
-        self.buckets = dict(zip(*kmeans([len(s.transformed[self.fields[0].name]) for s in self], n_buckets)))
+        self.buckets = dict(zip(*kmeans([len(s.transformed[fields[0].name]) for s in self], n_buckets)))
         self.loader = DataLoader(dataset=self,
                                  batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed),
-                                 collate_fn=self.collate_fn)
+                                 collate_fn=lambda x: Batch(x))
         return self


-class DataLoader(torch.utils.data.DataLoader):
-    r"""
-    DataLoader, matching with :class:`Dataset`.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-    def __iter__(self):
-        for batch in super().__iter__():
-            yield namedtuple('Batch', (f.name for f in batch.keys()))(*[f.compose(d) for f, d in batch.items()])
-
-
 class Sampler(torch.utils.data.Sampler):
     r"""
     Sampler that supports for bucketization and token-level batchification.
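
The dict-returning collate_fn and the namedtuple-yielding DataLoader subclass are gone; composing field tensors is now the job of the Batch class imported from supar.utils.transform. That module is presumably the sixth changed file, and its hunk does not appear on this page, so the following is only a plausible sketch inferred from usage: Batch is built from the raw sentence list in collate_fn and must yield one composed tensor per field when unpacked in the parser loops. The transformed/compose names are carried over from the deleted code; the flattened_fields accessor is a guess.

    class Batch(object):
        """Plausible sketch of supar.utils.transform.Batch (not shown in this diff)."""

        def __init__(self, sentences):
            self.sentences = sentences
            # guess: sentences expose their shared transform, whose fields
            # know how to pad raw values into batch tensors via compose()
            self.fields = sentences[0].transform.flattened_fields

        def __iter__(self):
            # one composed (padded) tensor per field, in field order,
            # which is exactly what `words, *feats, ... = batch` unpacks
            yield from (f.compose([s.transformed[f.name] for s in self.sentences])
                        for f in self.fields)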

supar/utils/field.py

Lines changed: 3 additions & 3 deletions

@@ -365,9 +365,9 @@ def transform(self, charts):
         charts = [self.preprocess(chart) for chart in charts]
         if self.use_vocab:
             charts = [[[self.vocab[i] if i is not None else -1 for i in row] for row in chart] for chart in charts]
+        charts = [torch.tensor(chart) for chart in charts]
         if self.bos:
-            charts = [[[self.bos_index]*len(chart[0])] + chart for chart in charts]
+            charts = [torch.cat((torch.empty_like(chart[:1]).fill_(self.bos_index), chart)) for chart in charts]
         if self.eos:
-            charts = [chart + [[self.eos_index]*len(chart[0])] for chart in charts]
-        charts = [torch.tensor(chart) for chart in charts]
+            charts = [torch.cat((chart, torch.empty_like(chart[:1]).fill_(self.eos_index))) for chart in charts]
         return charts
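
The tensor conversion now happens before padding instead of after, so the BOS/EOS rows are concatenated as tensors rather than prepended as Python lists. A quick sanity check on a toy 2x2 chart, with 98/99 standing in for bos_index/eos_index:

    import torch

    bos_index, eos_index = 98, 99
    chart = torch.tensor([[1, 2],
                          [3, 4]])
    # prepend a BOS row, then append an EOS row, each shaped like one chart row
    chart = torch.cat((torch.empty_like(chart[:1]).fill_(bos_index), chart))
    chart = torch.cat((chart, torch.empty_like(chart[:1]).fill_(eos_index)))
    print(chart)
    # tensor([[98, 98],
    #         [ 1,  2],
    #         [ 3,  4],
    #         [99, 99]])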
