@@ -334,15 +334,16 @@ def transform(self, sequences):
334
334
335
335
class ChartField (Field ):
336
336
r"""
337
- Field dealing with constituency trees.
338
-
339
- This field receives sequences of binarized trees factorized in pre-order,
340
- and returns charts filled with labels on each constituent.
337
+ Field dealing with chart inputs.
341
338
342
339
Examples:
343
- >>> sequence = [(0, 5, 'S'), (0, 4, 'S|<>'), (0, 1, 'NP'), (1, 4, 'VP'), (1, 2, 'VP|<>'),
344
- (2, 4, 'S+VP'), (2, 3, 'VP|<>'), (3, 4, 'NP'), (4, 5, 'S|<>')]
345
- >>> field.transform([sequence])[0]
340
+ >>> chart = [[ None, 'NP', None, None, 'S|<>', 'S'],
341
+ [ None, None, 'VP|<>', None, 'VP', None],
342
+ [ None, None, None, 'VP|<>', 'S+VP', None],
343
+ [ None, None, None, None, 'NP', None],
344
+ [ None, None, None, None, None, 'S|<>'],
345
+ [ None, None, None, None, None, None]]
346
+ >>> field.transform([chart])[0]
346
347
tensor([[ -1, 37, -1, -1, 107, 79],
347
348
[ -1, -1, 120, -1, 112, -1],
348
349
[ -1, -1, -1, 120, 86, -1],
@@ -352,19 +353,14 @@ class ChartField(Field):
352
353
"""
353
354
354
355
def build(self, dataset, min_freq=1):
    r"""
    Build ``self.vocab`` from the labels found in the charts of ``dataset``.

    The charts are read from the attribute of ``dataset`` named ``self.name``.
    Each chart is passed through ``self.preprocess`` and every cell that is
    not ``None`` is counted as one occurrence of its label.

    Args:
        dataset: Object exposing the charts under the attribute ``self.name``.
        min_freq: Forwarded unchanged to ``Vocab``
            (presumably a minimum label frequency -- confirm against ``Vocab``).
    """
    label_counts = Counter()
    for chart in getattr(dataset, self.name):
        for row in self.preprocess(chart):
            # ``None`` marks an empty chart cell and carries no label.
            label_counts.update(label for label in row if label is not None)

    self.vocab = Vocab(label_counts, min_freq, self.specials, self.unk_index)
360
362
361
- def transform (self , sequences ):
362
- charts = []
363
- for sequence in sequences :
364
- sequence = self .preprocess (sequence )
365
- seq_len = sequence [0 ][1 ] + 1
366
- chart = torch .full ((seq_len , seq_len ), - 1 , dtype = torch .long )
367
- for i , j , label in sequence :
368
- chart [i , j ] = self .vocab [label ]
369
- charts .append (chart )
363
def transform(self, charts):
    r"""
    Convert charts of labels into tensors of vocabulary indices.

    Each chart is first passed through ``self.preprocess``. Every cell that
    is not ``None`` is then looked up in ``self.vocab``, and every ``None``
    cell is encoded as ``-1``.

    Args:
        charts: Iterable of charts; each chart is an iterable of rows whose
            cells are either a label or ``None``.

    Returns:
        A list holding one ``torch.long`` tensor per input chart, in order.
    """
    indexed = []
    for chart in charts:
        rows = [[-1 if label is None else self.vocab[label] for label in row]
                for row in self.preprocess(chart)]
        # Pin the dtype explicitly: with dtype inference an empty chart (or a
        # chart made of empty rows) would come back as a float tensor rather
        # than an index tensor.
        indexed.append(torch.tensor(rows, dtype=torch.long))
    return indexed
0 commit comments