Skip to content

Commit 5886b08

Browse files
author
zysite
committed
Fix a typo
1 parent 845bf05 commit 5886b08

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

parser/utils/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def __init__(self, corpus, fields, n_buckets=1):
4343
value = field.numericalize(getattr(corpus, field.name))
4444
setattr(self, field.name, value)
4545
# NOTE: the final bucket count is roughly equal to n_buckets
46-
self.centroids, self.clusters = kmeans(x=[len(i) for i in corpus],
47-
k=n_buckets)
48-
self.buckets = dict(zip(self.centroids, self.clusters))
46+
self.lengths = [len(i) + sum([bool(field.bos), bool(field.bos)])
47+
for i in corpus]
48+
self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))
4949

5050
def __getitem__(self, index):
5151
for field in self.fields:
@@ -86,7 +86,7 @@ def __init__(self, buckets, batch_size, shuffle=False):
8686
]
8787

8888
def __iter__(self):
89-
# if shuffle, shffule both the buckets and samples in each bucket
89+
# if shuffle, shuffle both the buckets and samples in each bucket
9090
range_fn = torch.randperm if self.shuffle else torch.arange
9191
for i in range_fn(len(self.buckets)).tolist():
9292
split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1

0 commit comments

Comments
 (0)