
Commit 259fd95

Author: zysite

Apply bucket mechanism to DataLoader

1 parent 0adb1d3

File tree: 3 files changed (+98, -31 lines)

parser/utils/__init__.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 
-from .dataset import TextDataset, collate_fn
-from .reader import Corpus, Embedding
+from . import data
+from .corpus import Corpus
+from .embedding import Embedding
 from .vocab import Vocab
 
-
-__all__ = ['Corpus', 'Embedding', 'TextDataset', 'Vocab', 'collate_fn']
+__all__ = ['data', 'Corpus', 'Embedding', 'Vocab']
```
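For orientation, a minimal sketch of the import surface after this change; it assumes the package is importable as `parser.utils`, as the file tree suggests. Note that `TextDataset` and `collate_fn` are no longer re-exported at the top level; they now live in the new `data` submodule.

```python
# Illustrative only: the top-level names exposed after this commit.
from parser.utils import Corpus, Embedding, Vocab, data

# TextDataset, TextSampler, collate_fn and batchify moved under `data`
print(data.TextDataset, data.batchify)
```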

parser/utils/data.py (new file)

Lines changed: 94 additions & 0 deletions
```python
# -*- coding: utf-8 -*-

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, Sampler


def kmeans(x, k):
    x = torch.tensor(x, dtype=torch.float)
    # initialize k centroids randomly
    c, old = x[torch.randperm(len(x))[:k]], None
    # assign labels to each datapoint based on centroids
    dists, y = torch.abs_(x.unsqueeze(-1) - c).min(dim=-1)

    while old is None or not c.equal(old):
        # handle the empty clusters
        for i in range(k):
            # choose the farthest datapoint from the biggest cluster
            # and move it to the empty cluster
            if not y.eq(i).any():
                mask = y.eq(torch.arange(k).unsqueeze(-1))
                lens = mask.sum(dim=-1)
                biggest = mask[lens.argmax()].nonzero().view(-1)
                farthest = dists[biggest].argmax()
                y[biggest[farthest]] = i
        # update the centroids
        c, old = torch.tensor([x[y.eq(i)].mean() for i in range(k)]), c
        # re-assign all datapoints to clusters
        dists, y = torch.abs_(x.unsqueeze(-1) - c).min(dim=-1)
    clusters = [y.eq(i) for i in range(k)]
    clusters = [i.nonzero().view(-1).tolist() for i in clusters if i.any()]
    centroids = [round(x[i].mean().item()) for i in clusters]

    return centroids, clusters


def collate_fn(data):
    reprs = (pad_sequence(i, True) for i in zip(*data))
    if torch.cuda.is_available():
        reprs = (i.cuda() for i in reprs)

    return reprs


class TextSampler(Sampler):

    def __init__(self, lengths, batch_size, n_buckets, shuffle=False):
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        # NOTE: the final bucket count is less than or equal to n_buckets
        self.sizes, self.buckets = kmeans(x=lengths, k=n_buckets)
        self.chunks = [max(size * len(bucket) // self.batch_size, 1)
                       for size, bucket in zip(self.sizes, self.buckets)]

    def __iter__(self):
        # if shuffle, shuffle both the buckets and the samples in each bucket
        range_fn = torch.randperm if self.shuffle else torch.arange
        for i in range_fn(len(self.buckets)):
            for batch in range_fn(len(self.buckets[i])).chunk(self.chunks[i]):
                yield [self.buckets[i][j] for j in batch.tolist()]

    def __len__(self):
        return sum(self.chunks)


class TextDataset(Dataset):

    def __init__(self, items, n_buckets=1):
        super(TextDataset, self).__init__()

        self.items = items

    def __getitem__(self, index):
        return tuple(item[index] for item in self.items)

    def __len__(self):
        return len(self.items[0])

    @property
    def lengths(self):
        return [len(i) for i in self.items[0]]


def batchify(dataset, batch_size, n_buckets=1, shuffle=False):
    batch_sampler = TextSampler(lengths=dataset.lengths,
                                batch_size=batch_size,
                                n_buckets=n_buckets,
                                shuffle=shuffle)
    loader = DataLoader(dataset=dataset,
                        batch_sampler=batch_sampler,
                        collate_fn=collate_fn)

    return loader
```
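A quick hedged demo of the 1-D `kmeans` above. The lengths and `k` below are made-up values, not from the repository; because clusters that end up empty are filtered out at the end, the returned bucket count can be smaller than `k`.

```python
from parser.utils.data import kmeans

# toy sentence lengths with three obvious groups
lengths = [3, 4, 5, 11, 12, 13, 29, 30, 31, 30]
centroids, clusters = kmeans(lengths, k=3)

# centroids are the rounded mean lengths per surviving bucket
# (roughly [4, 12, 30] here); clusters hold indices into `lengths`
print(centroids, clusters)
```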

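And an end-to-end sketch of the new bucketed loader. Everything below is fabricated for illustration (the random ids, the field layout, and the parameter values); reading `batch_size` as a token budget rather than a sentence count is an inference from the `chunks` arithmetic in `TextSampler`, not something the commit states.

```python
import torch

from parser.utils.data import TextDataset, batchify

# fabricated corpus: 100 sentences, two aligned fields per sentence
# (say, word ids and tag ids), each field a 1-D tensor per sentence
lens = torch.randint(5, 40, (100,)).tolist()
words = [torch.randint(0, 1000, (n,)) for n in lens]
tags = [torch.randint(0, 50, (n,)) for n in lens]

dataset = TextDataset(items=(words, tags))
# each bucket is split into ~(centroid_length * n_sentences) / batch_size
# chunks, so batch_size behaves like a per-batch token budget
loader = batchify(dataset, batch_size=500, n_buckets=8, shuffle=True)

for words_batch, tags_batch in loader:
    # batches are padded only to the longest sentence in their bucket,
    # so grouping similar lengths keeps the padding overhead small
    assert words_batch.shape == tags_batch.shape
```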
parser/utils/dataset.py

Lines changed: 0 additions & 27 deletions
This file was deleted.
