
Commit 5903361

Data binarizer supported
Parent: 2c0eff6

4 files changed: +186, -108 lines

supar/cmds/cmd.py

Lines changed: 2 additions & 0 deletions
@@ -17,6 +17,8 @@ def init(parser):
     parser.add_argument('--seed', '-s', default=1, type=int, help='seed for generating random numbers')
     parser.add_argument('--threads', '-t', default=16, type=int, help='num of threads')
     parser.add_argument('--workers', '-w', default=0, type=int, help='num of processes used for data loading')
+    parser.add_argument('--cache', action='store_true', help='cache the data for fast loading')
+    parser.add_argument('--binarize', action='store_true', help='binarize the data first')
     args, unknown = parser.parse_known_args()
     args, unknown = parser.parse_known_args(unknown, args)
     args = Config.load(**vars(args), unknown=unknown)
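
The two new switches are plain argparse store-true flags shared by all subcommands. A minimal sketch of how they are meant to reach the Dataset constructor changed below; the hand-off line is illustrative, since the training routine itself is not part of this commit:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--cache', action='store_true', help='cache the data for fast loading')
    parser.add_argument('--binarize', action='store_true', help='binarize the data first')
    args = parser.parse_args(['--cache', '--binarize'])

    # hypothetical hand-off, mirroring the new Dataset signature below:
    # train = Dataset(transform, 'data/train.conllx', cache=args.cache, binarize=args.binarize)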

supar/utils/data.py

Lines changed: 47 additions & 15 deletions
@@ -2,11 +2,13 @@

 from __future__ import annotations

+import os
 from typing import Dict, Iterable, List, Union

 import torch
 import torch.distributed as dist
-from supar.utils.fn import kmeans
+from supar.utils.fn import debinarize, kmeans
+from supar.utils.logging import logger
 from supar.utils.transform import Batch, Transform
 from torch.utils.data import DataLoader

@@ -23,6 +25,13 @@ class Dataset(torch.utils.data.Dataset):
         The instance holds a series of loading and processing behaviours with regard to the specific data format.
         data (str or Iterable):
             A filename or a list of instances that will be passed into :meth:`transform.load`.
+        cache (bool):
+            If ``True``, tries to use the previously cached binarized data for fast loading.
+            In this way, sentences are loaded on-the-fly according to the meta data.
+            If ``False``, all sentences will be directly loaded into memory.
+            Default: ``False``.
+        binarize (bool):
+            If ``True``, binarizes the dataset when building it. Only works if ``cache=True``. Default: ``False``.
         kwargs (dict):
             Together with `data`, kwargs will be passed into :meth:`transform.load` to control the loading behaviour.
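
Concretely, the docstring describes three construction modes. A hedged sketch, with the path and a prebuilt transform assumed for illustration:

    from supar.utils.data import Dataset

    # eager: all sentences are loaded into memory (the old behaviour)
    train = Dataset(transform, 'data/train.conllx')
    # cached: reuse 'data/train.conllx.pt' if it exists, loading sentences lazily
    train = Dataset(transform, 'data/train.conllx', cache=True)
    # cached + forced (re-)binarization, e.g. after the transform has changed
    train = Dataset(transform, 'data/train.conllx', cache=True, binarize=True)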
@@ -32,18 +41,39 @@ class Dataset(torch.utils.data.Dataset):
         sentences (list[Sentence]):
             A list of sentences loaded from the data.
             Each sentence includes fields obeying the data format defined in ``transform``.
+            If ``cache=True``, each is a pointer to the sentence stored in the cache file.
     """

     def __init__(
         self,
         transform: Transform,
         data: Union[str, Iterable],
+        cache: bool = False,
+        binarize: bool = False,
         **kwargs
     ) -> Dataset:
         super(Dataset, self).__init__()

         self.transform = transform
-        self.sentences = transform.load(data, **kwargs)
+        self.data = data
+        self.cache = cache
+        self.binarize = binarize
+        self.kwargs = kwargs
+
+        if cache:
+            if not isinstance(data, str) or not os.path.exists(data):
+                raise RuntimeError("Only files are allowed in order to load/save the binarized data")
+            self.fbin = data + '.pt'
+            if self.binarize or not os.path.exists(self.fbin):
+                logger.info(f"Seeking to cache the data to {self.fbin} first")
+            else:
+                try:
+                    self.sentences = debinarize(self.fbin, meta=True)
+                except Exception:
+                    raise RuntimeError(f"Error found while debinarizing {self.fbin}, which may have been corrupted. "
+                                       "Try re-binarizing it first")
+        else:
+            self.sentences = list(transform.load(data, **kwargs))

     def __repr__(self):
         s = f"{self.__class__.__name__}("
@@ -60,22 +90,16 @@ def __len__(self):
         return len(self.sentences)

     def __getitem__(self, index):
-        return self.sentences[index]
+        return debinarize(self.fbin, self.sentences[index]) if self.cache else self.sentences[index]

     def __getattr__(self, name):
-        if name in self.__dict__:
-            return self.__dict__[name]
+        if name not in {f.name for f in self.transform.flattened_fields}:
+            raise AttributeError
+        if self.cache:
+            sentences = self if os.path.exists(self.fbin) else self.transform.load(self.data, **self.kwargs)
+            return (getattr(sentence, name) for sentence in sentences)
         return [getattr(sentence, name) for sentence in self.sentences]

-    def __setattr__(self, name, value):
-        if 'sentences' in self.__dict__ and name in self.sentences[0]:
-            # restore the order of sequences in the buckets
-            indices = torch.tensor([i for bucket in self.buckets.values() for i in bucket]).argsort()
-            for index, sentence in zip(indices, self.sentences):
-                setattr(sentence, name, value[index])
-        else:
-            self.__dict__[name] = value
-
     def __getstate__(self):
         return self.__dict__
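
binarize/debinarize themselves live in supar.utils.fn, which is not part of this excerpt. From the call sites above, debinarize(fbin, meta=True) returns the stored list of sentence pointers, while debinarize(fbin, pointer) materializes a single sentence. A self-contained sketch of a compatible pair, under the assumption that each pointer is an (offset, length) pair into the binarized file:

    import pickle

    def binarize(sentences, fbin):
        # pickled sentences back to back, then the pointer list, then its offset
        pointers = []
        with open(fbin, 'wb') as f:
            for sentence in sentences:
                data = pickle.dumps(sentence)
                pointers.append((f.tell(), len(data)))
                f.write(data)
            meta = f.tell()
            f.write(pickle.dumps(pointers))
            f.write(meta.to_bytes(8, 'little'))
        return pointers

    def debinarize(fbin, pointer=None, meta=False):
        with open(fbin, 'rb') as f:
            if meta:  # recover the pointer list stored at the tail of the file
                f.seek(-8, 2)
                f.seek(int.from_bytes(f.read(8), 'little'))
                return pickle.loads(f.read()[:-8])
            offset, length = pointer
            f.seek(offset)
            return pickle.loads(f.read(length))

With a layout like this, __getitem__ touches the disk only for the requested sentence, which is what keeps the cached mode memory-friendly.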

@@ -91,8 +115,16 @@ def build(
         n_workers: int = 0,
         pin_memory: bool = True
     ) -> Dataset:
+        fields = self.transform.flattened_fields
         # numericalize all fields
-        fields = self.transform(self.sentences)
+        if self.cache:
+            # if not forced to do binarization and the binarized file already exists, directly load the meta file
+            if os.path.exists(self.fbin) and not self.binarize:
+                self.sentences = debinarize(self.fbin, meta=True)
+            else:
+                self.sentences = self.transform(self.transform.load(self.data, **self.kwargs), self.fbin)
+        else:
+            self.sentences = self.transform(self.sentences)
         # NOTE: the final bucket count is roughly equal to n_buckets
         self.buckets = dict(zip(*kmeans([len(s.fields[fields[0].name]) for s in self], n_buckets)))
         self.loader = DataLoader(dataset=self,
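
Driving build after construction might look as follows; batch_size and the exact keyword names beyond n_buckets, n_workers and pin_memory are assumptions, since only those appear in this hunk:

    train = Dataset(transform, 'data/train.conllx', cache=True)
    train.build(batch_size=5000, n_buckets=32, n_workers=4)
    for batch in train.loader:
        ...  # batches are composed on the fly from the (possibly cached) sentences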

supar/utils/field.py

Lines changed: 51 additions & 58 deletions
@@ -3,12 +3,13 @@
 from __future__ import annotations

 from collections import Counter
-from typing import Callable, List, Optional
+from typing import Callable, Iterable, List, Optional

 import torch
 from supar.utils.data import Dataset
 from supar.utils.embed import Embedding
 from supar.utils.fn import pad
+from supar.utils.logging import progress_bar
 from supar.utils.vocab import Vocab

@@ -36,10 +37,10 @@ def __repr__(self):
     def preprocess(self, sequence: List) -> List:
         return self.fn(sequence) if self.fn is not None else sequence

-    def transform(self, sequences: List[List]) -> List[List]:
-        return [self.preprocess(seq) for seq in sequences]
+    def transform(self, sequences: Iterable[List]) -> Iterable[List]:
+        return (self.preprocess(seq) for seq in sequences)

-    def compose(self, sequences: List[List]) -> List[List]:
+    def compose(self, sequences: Iterable[List]) -> Iterable[List]:
         return sequences
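
This list-to-generator change is the pattern repeated for every field below: transform now evaluates lazily, so nothing is preprocessed until a consumer iterates, which lets sentences stream through binarization without materializing the whole corpus. A toy illustration (the file path is hypothetical):

    raw = RawField('text')
    out = raw.transform(line.split() for line in open('data/train.txt'))
    first = next(out)  # the first sequence is preprocessed only at this point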

@@ -102,6 +103,8 @@ def __init__(

     def __repr__(self):
         s, params = f"({self.name}): {self.__class__.__name__}(", []
+        if hasattr(self, 'vocab'):
+            params.append(f"vocab_size={len(self.vocab)}")
         if self.pad is not None:
             params.append(f"pad={self.pad}")
         if self.unk is not None:
@@ -114,10 +117,7 @@ def __repr__(self):
             params.append(f"lower={self.lower}")
         if not self.use_vocab:
             params.append(f"use_vocab={self.use_vocab}")
-        s += ", ".join(params)
-        s += ")"
-
-        return s
+        return s + ', '.join(params) + ')'

     def __getstate__(self):
         state = dict(self.__dict__)
@@ -210,9 +210,8 @@ def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding] = None, norm: Callable = None) -> None:

         if hasattr(self, 'vocab'):
             return
-        sequences = getattr(dataset, self.name)
         counter = Counter(token
-                          for seq in sequences
+                          for seq in progress_bar(getattr(dataset, self.name))
                           for token in self.preprocess(seq))
         self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index)
@@ -231,44 +230,43 @@ def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding] = None, norm: Callable = None) -> None:
         if norm is not None:
             self.embed = norm(self.embed)

-    def transform(self, sequences: List[List[str]]) -> List[torch.Tensor]:
+    def transform(self, sequences: Iterable[List[str]]) -> Iterable[torch.Tensor]:
         r"""
         Turns a list of sequences that use this field into tensors.

         Each sequence is first preprocessed and then numericalized if needed.

         Args:
-            sequences (list[list[str]]):
+            sequences (Iterable[list[str]]):
                 A list of sequences.

         Returns:
             A list of tensors transformed from the input sequences.
         """

-        sequences = [self.preprocess(seq) for seq in sequences]
-        if self.use_vocab:
-            sequences = [self.vocab[seq] for seq in sequences]
-        if self.bos:
-            sequences = [[self.bos_index] + seq for seq in sequences]
-        if self.eos:
-            sequences = [seq + [self.eos_index] for seq in sequences]
-        sequences = [torch.tensor(seq) for seq in sequences]
-
-        return sequences
-
-    def compose(self, sequences: List[torch.Tensor]) -> torch.Tensor:
+        for seq in sequences:
+            seq = self.preprocess(seq)
+            if self.use_vocab:
+                seq = self.vocab[seq]
+            if self.bos:
+                seq = [self.bos_index] + seq
+            if self.eos:
+                seq = seq + [self.eos_index]
+            yield torch.tensor(seq)
+
+    def compose(self, batch: List[torch.Tensor]) -> torch.Tensor:
         r"""
         Composes a batch of sequences into a padded tensor.

         Args:
-            sequences (list[~torch.Tensor]):
+            batch (list[~torch.Tensor]):
                 A list of tensors.

         Returns:
             A padded tensor converted to proper device.
         """

-        return pad(sequences, self.pad_index).to(self.device)
+        return pad(batch, self.pad_index).to(self.device)


 class SubwordField(Field):
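
Stepping back from the diff for a moment: with these changes, a Field is built and applied roughly as follows. The construction mirrors supar's usual usage, but the concrete values are illustrative:

    field = Field('words', pad='<pad>', unk='<unk>', bos='<bos>', lower=True)
    field.build(train)  # vocab counting over the dataset, now with a progress bar
    tensors = field.transform([['She', 'enjoys', 'playing', 'tennis', '.']])
    batch = field.compose(list(tensors))  # pad the realized tensors into a batch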
@@ -295,7 +293,7 @@ class SubwordField(Field):
                                  fix_len=20,
                                  tokenize=tokenizer.tokenize)
         >>> field.vocab = tokenizer.get_vocab()  # no need to re-build the vocab
-        >>> field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']])[0]
+        >>> next(field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']]))
         tensor([[ 101,    0,    0],
                 [1188,    0,    0],
                 [1768,    0,    0],
@@ -312,9 +310,8 @@ def __init__(self, *args, **kwargs):
     def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding] = None, norm: Callable = None) -> None:
         if hasattr(self, 'vocab'):
             return
-        sequences = getattr(dataset, self.name)
         counter = Counter(piece
-                          for seq in sequences
+                          for seq in progress_bar(getattr(dataset, self.name))
                           for token in seq
                           for piece in self.preprocess(token))
         self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index)
@@ -334,23 +331,19 @@ def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding] = None, norm: Callable = None) -> None:
         if norm is not None:
             self.embed = norm(self.embed)

-    def transform(self, sequences: List[List[str]]) -> List[torch.Tensor]:
-        sequences = [[self.preprocess(token) for token in seq]
-                     for seq in sequences]
-        if self.fix_len <= 0:
-            self.fix_len = max(len(token) for seq in sequences for token in seq)
-        if self.use_vocab:
-            sequences = [[[self.vocab[i] if i in self.vocab else self.unk_index for i in token] if token else [self.unk_index]
-                          for token in seq] for seq in sequences]
-        if self.bos:
-            sequences = [[[self.bos_index]] + seq for seq in sequences]
-        if self.eos:
-            sequences = [seq + [[self.eos_index]] for seq in sequences]
-        lens = [min(self.fix_len, max(len(ids) for ids in seq)) for seq in sequences]
-        sequences = [pad([torch.tensor(ids[:i]) for ids in seq], self.pad_index, i)
-                     for i, seq in zip(lens, sequences)]
-
-        return sequences
+    def transform(self, sequences: Iterable[List[str]]) -> Iterable[torch.Tensor]:
+        for seq in sequences:
+            seq = [self.preprocess(token) for token in seq]
+            if self.use_vocab:
+                seq = [[self.vocab[i] if i in self.vocab else self.unk_index for i in token] if token else [self.unk_index]
+                       for token in seq]
+            if self.bos:
+                seq = [[self.bos_index]] + seq
+            if self.eos:
+                seq = seq + [[self.eos_index]]
+            if self.fix_len > 0:
+                seq = [ids[:self.fix_len] for ids in seq]
+            yield pad([torch.tensor(ids) for ids in seq], self.pad_index)


 class ChartField(Field):
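
One nuance in the SubwordField.transform rewrite above: the old code rewrote self.fix_len to the global maximum when it was non-positive, which required seeing all sequences up front and is impossible for a streaming generator. The new code leaves fix_len untouched, truncates pieces per token only when fix_len is positive, and lets pad size each sequence by its longest remaining token; the per-sequence result is unchanged, while the state mutation and the full pass over the data disappear. Roughly:

    # old: pad(..., self.pad_index, i) forced every sequence to width i = min(fix_len, longest)
    # new: pad(..., self.pad_index) sizes each sequence by its longest truncated token
    pieces = field.transform([['token-level', 'tokenization']])  # field as in the docstring
    print(next(pieces).shape)  # e.g. torch.Size([n_tokens, n_pieces]), widths equal in both versions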
@@ -364,7 +357,7 @@ class ChartField(Field):
                 [ None, None, None, None, 'NP', None],
                 [ None, None, None, None, None, 'S|<>'],
                 [ None, None, None, None, None, None]]
-        >>> field.transform([chart])[0]
+        >>> next(field.transform([chart]))
         tensor([[ -1,  37,  -1,  -1, 107,  79],
                 [ -1,  -1, 120,  -1, 112,  -1],
                 [ -1,  -1,  -1, 120,  86,  -1],
@@ -375,19 +368,19 @@ class ChartField(Field):

     def build(self, dataset: Dataset, min_freq: int = 1) -> None:
         counter = Counter(i
-                          for chart in getattr(dataset, self.name)
+                          for chart in progress_bar(getattr(dataset, self.name))
                           for row in self.preprocess(chart)
                           for i in row if i is not None)

         self.vocab = Vocab(counter, min_freq, self.specials, self.unk_index)

-    def transform(self, charts: List[List[List]]) -> List[torch.Tensor]:
-        charts = [self.preprocess(chart) for chart in charts]
-        if self.use_vocab:
-            charts = [[[self.vocab[i] if i is not None else -1 for i in row] for row in chart] for chart in charts]
-        if self.bos:
-            charts = [[[self.bos_index]*len(chart[0])] + chart for chart in charts]
-        if self.eos:
-            charts = [chart + [[self.eos_index]*len(chart[0])] for chart in charts]
-        charts = [torch.tensor(chart) for chart in charts]
-        return charts
+    def transform(self, charts: Iterable[List[List]]) -> Iterable[torch.Tensor]:
+        for chart in charts:
+            chart = self.preprocess(chart)
+            if self.use_vocab:
+                chart = [[self.vocab[i] if i is not None else -1 for i in row] for row in chart]
+            if self.bos:
+                chart = [[self.bos_index]*len(chart[0])] + chart
+            if self.eos:
+                chart = chart + [[self.eos_index]*len(chart[0])]
+            yield torch.tensor(chart)
