Skip to content

Commit a246cb2

Browse files
committed
Support indexing dataset
1 parent 032bc2f commit a246cb2

File tree

2 files changed

+92
-90
lines changed

2 files changed

+92
-90
lines changed

supar/utils/data.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,25 +9,22 @@
99

1010
class Dataset(torch.utils.data.Dataset):
1111
r"""
12-
Dataset that is compatible with :class:`torch.utils.data.Dataset`.
13-
This serves as a wrapper for manipulating all data fields
14-
with the operating behaviours defined in :class:`Transform`.
12+
Dataset that is compatible with :class:`torch.utils.data.Dataset`, serving as a wrapper for manipulating all data fields
13+
with the operating behaviours defined in :class:`~supar.utils.transform.Transform`.
1514
The data fields of all the instantiated sentences can be accessed as an attribute of the dataset.
1615
1716
Args:
1817
transform (Transform):
19-
An instance of :class:`Transform` and its derivations.
20-
The instance holds a series of loading and processing behaviours with regard to the specfic data format.
18+
An instance of :class:`~supar.utils.transform.Transform` or its derivations.
19+
The instance holds a series of loading and processing behaviours with regard to the specific data format.
2120
data (list[list] or str):
22-
A list of instances or a filename.
23-
This will be passed into :meth:`transform.load`.
21+
A list of instances or a filename that will be passed into :meth:`transform.load`.
2422
kwargs (dict):
25-
Keyword arguments that will be passed into :meth:`transform.load` together with `data`
26-
to control the loading behaviour.
23+
Together with `data`, kwargs will be passed into :meth:`transform.load` to control the loading behaviour.
2724
2825
Attributes:
2926
transform (Transform):
30-
An instance of :class:`Transform`.
27+
An instance of :class:`~supar.utils.transform.Transform`.
3128
sentences (list[Sentence]):
3229
A list of sentences loaded from the data.
3330
Each sentence includes fields obeying the data format defined in ``transform``.
@@ -54,10 +51,7 @@ def __len__(self):
5451
return len(self.sentences)
5552

5653
def __getitem__(self, index):
57-
if not hasattr(self, 'fields'):
58-
raise RuntimeError("The fields are not numericalized. Please build the dataset first.")
59-
for d in self.fields.values():
60-
yield d[index]
54+
return self.sentences[index]
6155

6256
def __getattr__(self, name):
6357
if name in self.__dict__:
@@ -67,9 +61,7 @@ def __getattr__(self, name):
6761
def __setattr__(self, name, value):
6862
if 'sentences' in self.__dict__ and name in self.sentences[0]:
6963
# restore the order of sequences in the buckets
70-
indices = torch.tensor([i
71-
for bucket in self.buckets.values()
72-
for i in bucket]).argsort()
64+
indices = torch.tensor([i for bucket in self.buckets.values() for i in bucket]).argsort()
7365
for index, sentence in zip(indices, self.sentences):
7466
setattr(sentence, name, value[index])
7567
else:
@@ -83,19 +75,17 @@ def __setstate__(self, state):
8375
self.__dict__.update(state)
8476

8577
def collate_fn(self, batch):
86-
return {f: d for f, d in zip(self.fields.keys(), zip(*batch))}
78+
if not hasattr(self, 'fields'):
79+
raise RuntimeError("The fields are not numericalized yet. Please build the dataset first.")
80+
return {f: [s.transformed[f.name] for s in batch] for f in self.fields}
8781

8882
def build(self, batch_size, n_buckets=1, shuffle=False, distributed=False):
8983
# numericalize all fields
9084
self.fields = self.transform(self.sentences)
9185
# NOTE: the final bucket count is roughly equal to n_buckets
92-
self.lengths = [len(i) for i in self.fields[next(iter(self.fields))]]
93-
self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))
86+
self.buckets = dict(zip(*kmeans([len(s.transformed[self.fields[0].name]) for s in self], n_buckets)))
9487
self.loader = DataLoader(dataset=self,
95-
batch_sampler=Sampler(buckets=self.buckets,
96-
batch_size=batch_size,
97-
shuffle=shuffle,
98-
distributed=distributed),
88+
batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed),
9989
collate_fn=self.collate_fn)
10090

10191

@@ -109,7 +99,7 @@ def __init__(self, *args, **kwargs):
10999

110100
def __iter__(self):
111101
for batch in super().__iter__():
112-
yield namedtuple('Batch', [f.name for f in batch.keys()])(*[f.compose(d) for f, d in batch.items()])
102+
yield namedtuple('Batch', (f.name for f in batch.keys()))(*[f.compose(d) for f, d in batch.items()])
113103

114104

115105
class Sampler(torch.utils.data.Sampler):
@@ -148,15 +138,14 @@ def __iter__(self):
148138
g.manual_seed(self.epoch)
149139
range_fn = torch.arange
150140
# if `shuffle=True`, shuffle both the buckets and samples in each bucket
151-
# for distributed training, make sure each process genertes the same random sequence at each epoch
141+
# for distributed training, make sure each process generates the same random sequence at each epoch
152142
if self.shuffle:
153143
def range_fn(x):
154144
return torch.randperm(x, generator=g)
155145
total, count = 0, 0
156146
# TODO: more elegant way to deal with uneven data, which we directly discard right now
157147
for i in range_fn(len(self.buckets)).tolist():
158-
split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1
159-
for j in range(self.chunks[i])]
148+
split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1 for j in range(self.chunks[i])]
160149
# DON'T use `torch.chunk` which may return wrong number of chunks
161150
for batch in range_fn(len(self.buckets[i])).split(split_sizes):
162151
if count == self.samples:

0 commit comments

Comments (0)