3
3
from __future__ import annotations
4
4
5
5
from collections import Counter
6
- from typing import Callable , List , Optional
6
+ from typing import Callable , Iterable , List , Optional
7
7
8
8
import torch
9
9
from supar .utils .data import Dataset
10
10
from supar .utils .embed import Embedding
11
11
from supar .utils .fn import pad
12
+ from supar .utils .logging import progress_bar
12
13
from supar .utils .vocab import Vocab
13
14
14
15
@@ -36,10 +37,10 @@ def __repr__(self):
36
37
def preprocess (self , sequence : List ) -> List :
37
38
return self .fn (sequence ) if self .fn is not None else sequence
38
39
39
- def transform (self , sequences : List [List ]) -> List [List ]:
40
- return [ self .preprocess (seq ) for seq in sequences ]
40
+ def transform (self , sequences : Iterable [List ]) -> Iterable [List ]:
41
+ return ( self .preprocess (seq ) for seq in sequences )
41
42
42
- def compose (self , sequences : List [List ]) -> List [List ]:
43
+ def compose (self , sequences : Iterable [List ]) -> Iterable [List ]:
43
44
return sequences
44
45
45
46
@@ -102,6 +103,8 @@ def __init__(
102
103
103
104
def __repr__ (self ):
104
105
s , params = f"({ self .name } ): { self .__class__ .__name__ } (" , []
106
+ if hasattr (self , 'vocab' ):
107
+ params .append (f"vocab_size={ len (self .vocab )} " )
105
108
if self .pad is not None :
106
109
params .append (f"pad={ self .pad } " )
107
110
if self .unk is not None :
@@ -114,10 +117,7 @@ def __repr__(self):
114
117
params .append (f"lower={ self .lower } " )
115
118
if not self .use_vocab :
116
119
params .append (f"use_vocab={ self .use_vocab } " )
117
- s += ", " .join (params )
118
- s += ")"
119
-
120
- return s
120
+ return s + ', ' .join (params ) + ')'
121
121
122
122
def __getstate__ (self ):
123
123
state = dict (self .__dict__ )
@@ -210,9 +210,8 @@ def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding]
210
210
211
211
if hasattr (self , 'vocab' ):
212
212
return
213
- sequences = getattr (dataset , self .name )
214
213
counter = Counter (token
215
- for seq in sequences
214
+ for seq in progress_bar ( getattr ( dataset , self . name ))
216
215
for token in self .preprocess (seq ))
217
216
self .vocab = Vocab (counter , min_freq , self .specials , self .unk_index )
218
217
@@ -231,44 +230,43 @@ def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding]
231
230
if norm is not None :
232
231
self .embed = norm (self .embed )
233
232
234
- def transform (self , sequences : List [List [str ]]) -> List [torch .Tensor ]:
233
+ def transform (self , sequences : Iterable [List [str ]]) -> Iterable [torch .Tensor ]:
235
234
r"""
236
235
Turns a list of sequences that use this field into tensors.
237
236
238
237
Each sequence is first preprocessed and then numericalized if needed.
239
238
240
239
Args:
241
- sequences (list [list[str]]):
240
+ sequences (Iterable [list[str]]):
242
241
A list of sequences.
243
242
244
243
Returns:
245
244
A list of tensors transformed from the input sequences.
246
245
"""
247
246
248
- sequences = [self .preprocess (seq ) for seq in sequences ]
249
- if self .use_vocab :
250
- sequences = [self .vocab [seq ] for seq in sequences ]
251
- if self .bos :
252
- sequences = [[self .bos_index ] + seq for seq in sequences ]
253
- if self .eos :
254
- sequences = [seq + [self .eos_index ] for seq in sequences ]
255
- sequences = [torch .tensor (seq ) for seq in sequences ]
256
-
257
- return sequences
258
-
259
- def compose (self , sequences : List [torch .Tensor ]) -> torch .Tensor :
247
+ for seq in sequences :
248
+ seq = self .preprocess (seq )
249
+ if self .use_vocab :
250
+ seq = self .vocab [seq ]
251
+ if self .bos :
252
+ seq = [self .bos_index ] + seq
253
+ if self .eos :
254
+ seq = seq + [self .eos_index ]
255
+ yield torch .tensor (seq )
256
+
257
+ def compose (self , batch : List [torch .Tensor ]) -> torch .Tensor :
260
258
r"""
261
259
Composes a batch of sequences into a padded tensor.
262
260
263
261
Args:
264
- sequences (list[~torch.Tensor]):
262
+ batch (list[~torch.Tensor]):
265
263
A list of tensors.
266
264
267
265
Returns:
268
266
A padded tensor converted to proper device.
269
267
"""
270
268
271
- return pad (sequences , self .pad_index ).to (self .device )
269
+ return pad (batch , self .pad_index ).to (self .device )
272
270
273
271
274
272
class SubwordField (Field ):
@@ -295,7 +293,7 @@ class SubwordField(Field):
295
293
fix_len=20,
296
294
tokenize=tokenizer.tokenize)
297
295
>>> field.vocab = tokenizer.get_vocab() # no need to re-build the vocab
298
- >>> field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']])[0]
296
+ >>> next( field.transform([['This', 'field', 'performs', 'token-level', 'tokenization']]))
299
297
tensor([[ 101, 0, 0],
300
298
[ 1188, 0, 0],
301
299
[ 1768, 0, 0],
@@ -312,9 +310,8 @@ def __init__(self, *args, **kwargs):
312
310
def build (self , dataset : Dataset , min_freq : int = 1 , embed : Optional [Embedding ] = None , norm : Callable = None ) -> None :
313
311
if hasattr (self , 'vocab' ):
314
312
return
315
- sequences = getattr (dataset , self .name )
316
313
counter = Counter (piece
317
- for seq in sequences
314
+ for seq in progress_bar ( getattr ( dataset , self . name ))
318
315
for token in seq
319
316
for piece in self .preprocess (token ))
320
317
self .vocab = Vocab (counter , min_freq , self .specials , self .unk_index )
@@ -334,23 +331,19 @@ def build(self, dataset: Dataset, min_freq: int = 1, embed: Optional[Embedding]
334
331
if norm is not None :
335
332
self .embed = norm (self .embed )
336
333
337
- def transform (self , sequences : List [List [str ]]) -> List [torch .Tensor ]:
338
- sequences = [[self .preprocess (token ) for token in seq ]
339
- for seq in sequences ]
340
- if self .fix_len <= 0 :
341
- self .fix_len = max (len (token ) for seq in sequences for token in seq )
342
- if self .use_vocab :
343
- sequences = [[[self .vocab [i ] if i in self .vocab else self .unk_index for i in token ] if token else [self .unk_index ]
344
- for token in seq ] for seq in sequences ]
345
- if self .bos :
346
- sequences = [[[self .bos_index ]] + seq for seq in sequences ]
347
- if self .eos :
348
- sequences = [seq + [[self .eos_index ]] for seq in sequences ]
349
- lens = [min (self .fix_len , max (len (ids ) for ids in seq )) for seq in sequences ]
350
- sequences = [pad ([torch .tensor (ids [:i ]) for ids in seq ], self .pad_index , i )
351
- for i , seq in zip (lens , sequences )]
352
-
353
- return sequences
334
+ def transform (self , sequences : Iterable [List [str ]]) -> Iterable [torch .Tensor ]:
335
+ for seq in sequences :
336
+ seq = [self .preprocess (token ) for token in seq ]
337
+ if self .use_vocab :
338
+ seq = [[self .vocab [i ] if i in self .vocab else self .unk_index for i in token ] if token else [self .unk_index ]
339
+ for token in seq ]
340
+ if self .bos :
341
+ seq = [[self .bos_index ]] + seq
342
+ if self .eos :
343
+ seq = seq + [[self .eos_index ]]
344
+ if self .fix_len > 0 :
345
+ seq = [ids [:self .fix_len ] for ids in seq ]
346
+ yield pad ([torch .tensor (ids ) for ids in seq ], self .pad_index )
354
347
355
348
356
349
class ChartField (Field ):
@@ -364,7 +357,7 @@ class ChartField(Field):
364
357
[ None, None, None, None, 'NP', None],
365
358
[ None, None, None, None, None, 'S|<>'],
366
359
[ None, None, None, None, None, None]]
367
- >>> field.transform([chart])[0]
360
+ >>> next( field.transform([chart]))
368
361
tensor([[ -1, 37, -1, -1, 107, 79],
369
362
[ -1, -1, 120, -1, 112, -1],
370
363
[ -1, -1, -1, 120, 86, -1],
@@ -375,19 +368,19 @@ class ChartField(Field):
375
368
376
369
def build (self , dataset : Dataset , min_freq : int = 1 ) -> None :
377
370
counter = Counter (i
378
- for chart in getattr (dataset , self .name )
371
+ for chart in progress_bar ( getattr (dataset , self .name ) )
379
372
for row in self .preprocess (chart )
380
373
for i in row if i is not None )
381
374
382
375
self .vocab = Vocab (counter , min_freq , self .specials , self .unk_index )
383
376
384
- def transform (self , charts : List [List [List ]]) -> List [torch .Tensor ]:
385
- charts = [ self . preprocess ( chart ) for chart in charts ]
386
- if self .use_vocab :
387
- charts = [[[ self .vocab [ i ] if i is not None else - 1 for i in row ] for row in chart ] for chart in charts ]
388
- if self .bos :
389
- charts = [[[ self .bos_index ] * len ( chart [ 0 ])] + chart for chart in charts ]
390
- if self .eos :
391
- charts = [ chart + [[ self .eos_index ] * len ( chart [ 0 ])] for chart in charts ]
392
- charts = [ torch . tensor ( chart ) for chart in charts ]
393
- return charts
377
+ def transform (self , charts : Iterable [List [List ]]) -> Iterable [torch .Tensor ]:
378
+ for chart in charts :
379
+ chart = self .preprocess ( chart )
380
+ if self .use_vocab :
381
+ chart = [[ self .vocab [ i ] if i is not None else - 1 for i in row ] for row in chart ]
382
+ if self .bos :
383
+ chart = [[ self .bos_index ] * len ( chart [ 0 ])] + chart
384
+ if self .eos :
385
+ chart = chart + [[ self . eos_index ] * len ( chart [ 0 ]) ]
386
+ yield torch . tensor ( chart )
0 commit comments