Skip to content

Commit 207d38e

Browse files
committed
Batch object
1 parent 6812efb commit 207d38e

File tree

2 files changed

+149
-145
lines changed

2 files changed

+149
-145
lines changed

supar/utils/data.py

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# -*- coding: utf-8 -*-
22

3-
from collections import namedtuple
4-
53
import torch
64
import torch.distributed as dist
75
from supar.utils.alg import kmeans
6+
from supar.utils.transform import Batch
7+
from torch.utils.data import DataLoader
88

99

1010
class Dataset(torch.utils.data.Dataset):
@@ -74,32 +74,14 @@ def __getstate__(self):
7474
def __setstate__(self, state):
7575
self.__dict__.update(state)
7676

77-
def collate_fn(self, batch):
78-
if not hasattr(self, 'fields'):
79-
raise RuntimeError("The fields are not numericalized yet. Please build the dataset first.")
80-
return {f: [s.transformed[f.name] for s in batch] for f in self.fields}
81-
8277
def build(self, batch_size, n_buckets=1, shuffle=False, distributed=False):
8378
# numericalize all fields
84-
self.fields = self.transform(self.sentences)
79+
fields = self.transform(self.sentences)
8580
# NOTE: the final bucket count is roughly equal to n_buckets
86-
self.buckets = dict(zip(*kmeans([len(s.transformed[self.fields[0].name]) for s in self], n_buckets)))
81+
self.buckets = dict(zip(*kmeans([len(s.transformed[fields[0].name]) for s in self], n_buckets)))
8782
self.loader = DataLoader(dataset=self,
8883
batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed),
89-
collate_fn=self.collate_fn)
90-
91-
92-
class DataLoader(torch.utils.data.DataLoader):
93-
r"""
94-
DataLoader, matching with :class:`Dataset`.
95-
"""
96-
97-
def __init__(self, *args, **kwargs):
98-
super().__init__(*args, **kwargs)
99-
100-
def __iter__(self):
101-
for batch in super().__iter__():
102-
yield namedtuple('Batch', (f.name for f in batch.keys()))(*[f.compose(d) for f, d in batch.items()])
84+
collate_fn=lambda x: Batch(x))
10385

10486

10587
class Sampler(torch.utils.data.Sampler):

supar/utils/transform.py

Lines changed: 144 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -83,57 +83,6 @@ def save(self, path, sentences):
8383
f.write('\n'.join([str(i) for i in sentences]) + '\n')
8484

8585

86-
class Sentence(object):
87-
r"""
88-
A Sentence object holds a sentence with regard to specific data format.
89-
"""
90-
91-
def __init__(self, transform):
92-
self.transform = transform
93-
94-
# mapping from each nested field to their proper position
95-
self.maps = dict()
96-
# names of each field
97-
self.keys = set()
98-
for i, field in enumerate(self.transform):
99-
if not isinstance(field, Iterable):
100-
field = [field]
101-
for f in field:
102-
if f is not None:
103-
self.maps[f.name] = i
104-
self.keys.add(f.name)
105-
# original values and numericalized values of each position
106-
self.values = []
107-
self.transformed = {key: None for key in self.keys}
108-
109-
def __contains__(self, key):
110-
return key in self.keys
111-
112-
def __getattr__(self, name):
113-
if name in self.__dict__:
114-
return self.__dict__[name]
115-
elif name in self.maps:
116-
return self.values[self.maps[name]]
117-
else:
118-
raise AttributeError
119-
120-
def __setattr__(self, name, value):
121-
if 'keys' in self.__dict__ and name in self:
122-
index = self.maps[name]
123-
if index >= len(self.values):
124-
self.__dict__[name] = value
125-
else:
126-
self.values[index] = value
127-
else:
128-
self.__dict__[name] = value
129-
130-
def __getstate__(self):
131-
return vars(self)
132-
133-
def __setstate__(self, state):
134-
self.__dict__.update(state)
135-
136-
13786
class CoNLL(Transform):
13887
r"""
13988
The CoNLL object holds ten fields required for CoNLL-X data format :cite:`buchholz-marsi-2006-conll`.
@@ -402,77 +351,6 @@ def load(self, data, lang=None, proj=False, max_len=None, **kwargs):
402351
return sentences
403352

404353

405-
class CoNLLSentence(Sentence):
406-
r"""
407-
Sencence in CoNLL-X format.
408-
409-
Args:
410-
transform (CoNLL):
411-
A :class:`~supar.utils.transform.CoNLL` object.
412-
lines (list[str]):
413-
A list of strings composing a sentence in CoNLL-X format.
414-
Comments and non-integer IDs are permitted.
415-
416-
Examples:
417-
>>> lines = ['# text = But I found the location wonderful and the neighbors very kind.',
418-
'1\tBut\t_\t_\t_\t_\t_\t_\t_\t_',
419-
'2\tI\t_\t_\t_\t_\t_\t_\t_\t_',
420-
'3\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
421-
'4\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
422-
'5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_',
423-
'6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_',
424-
'7\tand\t_\t_\t_\t_\t_\t_\t_\t_',
425-
'7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
426-
'8\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
427-
'9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_',
428-
'10\tvery\t_\t_\t_\t_\t_\t_\t_\t_',
429-
'11\tkind\t_\t_\t_\t_\t_\t_\t_\t_',
430-
'12\t.\t_\t_\t_\t_\t_\t_\t_\t_']
431-
>>> sentence = CoNLLSentence(transform, lines) # fields in transform are built from ptb.
432-
>>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3]
433-
>>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp',
434-
'cc', 'det', 'dep', 'advmod', 'conj', 'punct']
435-
>>> sentence
436-
# text = But I found the location wonderful and the neighbors very kind.
437-
1 But _ _ _ _ 3 cc _ _
438-
2 I _ _ _ _ 3 nsubj _ _
439-
3 found _ _ _ _ 0 root _ _
440-
4 the _ _ _ _ 5 det _ _
441-
5 location _ _ _ _ 6 nsubj _ _
442-
6 wonderful _ _ _ _ 3 xcomp _ _
443-
7 and _ _ _ _ 6 cc _ _
444-
7.1 found _ _ _ _ _ _ _ _
445-
8 the _ _ _ _ 9 det _ _
446-
9 neighbors _ _ _ _ 11 dep _ _
447-
10 very _ _ _ _ 11 advmod _ _
448-
11 kind _ _ _ _ 6 conj _ _
449-
12 . _ _ _ _ 3 punct _ _
450-
"""
451-
452-
def __init__(self, transform, lines):
453-
super().__init__(transform)
454-
455-
self.values = []
456-
# record annotations for post-recovery
457-
self.annotations = dict()
458-
459-
for i, line in enumerate(lines):
460-
value = line.split('\t')
461-
if value[0].startswith('#') or not value[0].isdigit():
462-
self.annotations[-i-1] = line
463-
else:
464-
self.annotations[len(self.values)] = line
465-
self.values.append(value)
466-
self.values = list(zip(*self.values))
467-
468-
def __repr__(self):
469-
# cover the raw lines
470-
merged = {**self.annotations,
471-
**{i: '\t'.join(map(str, line))
472-
for i, line in enumerate(zip(*self.values))}}
473-
return '\n'.join(merged.values()) + '\n'
474-
475-
476354
class Tree(Transform):
477355
r"""
478356
The Tree object factorize a constituency tree into four fields,
@@ -741,6 +619,150 @@ def load(self, data, lang=None, max_len=None, **kwargs):
741619
return sentences
742620

743621

622+
class Batch(object):
623+
624+
def __init__(self, sentences):
625+
self.sentences = sentences
626+
self.transformed = {f.name: f.compose([s.transformed[f.name] for s in sentences])
627+
for f in sentences[0].transform.flattened_fields}
628+
self.fields = list(self.transformed.keys())
629+
630+
def __repr__(self):
631+
s = ', '.join([f"{name}" for name in self.fields])
632+
return f"{self.__class__.__name__}({s})"
633+
634+
def __getitem__(self, index):
635+
return self.transformed[self.fields[index]]
636+
637+
def __getattr__(self, name):
638+
if name in self.__dict__:
639+
return self.__dict__[name]
640+
if name in self.transformed:
641+
return self.transformed[name]
642+
if hasattr(self.sentences[0], name):
643+
return [getattr(s, name) for s in self.sentences]
644+
raise AttributeError
645+
646+
647+
class Sentence(object):
648+
649+
def __init__(self, transform):
650+
self.transform = transform
651+
652+
# mapping from each nested field to their proper position
653+
self.maps = dict()
654+
# names of each field
655+
self.keys = set()
656+
for i, field in enumerate(self.transform):
657+
if not isinstance(field, Iterable):
658+
field = [field]
659+
for f in field:
660+
if f is not None:
661+
self.maps[f.name] = i
662+
self.keys.add(f.name)
663+
# original values and numericalized values of each position
664+
self.values = []
665+
self.transformed = {key: None for key in self.keys}
666+
667+
def __contains__(self, key):
668+
return key in self.keys
669+
670+
def __getattr__(self, name):
671+
if name in self.__dict__:
672+
return self.__dict__[name]
673+
elif name in self.maps:
674+
return self.values[self.maps[name]]
675+
else:
676+
raise AttributeError
677+
678+
def __setattr__(self, name, value):
679+
if 'keys' in self.__dict__ and name in self:
680+
index = self.maps[name]
681+
if index >= len(self.values):
682+
self.__dict__[name] = value
683+
else:
684+
self.values[index] = value
685+
else:
686+
self.__dict__[name] = value
687+
688+
def __getstate__(self):
689+
return vars(self)
690+
691+
def __setstate__(self, state):
692+
self.__dict__.update(state)
693+
694+
695+
class CoNLLSentence(Sentence):
696+
r"""
697+
Sencence in CoNLL-X format.
698+
699+
Args:
700+
transform (CoNLL):
701+
A :class:`~supar.utils.transform.CoNLL` object.
702+
lines (list[str]):
703+
A list of strings composing a sentence in CoNLL-X format.
704+
Comments and non-integer IDs are permitted.
705+
706+
Examples:
707+
>>> lines = ['# text = But I found the location wonderful and the neighbors very kind.',
708+
'1\tBut\t_\t_\t_\t_\t_\t_\t_\t_',
709+
'2\tI\t_\t_\t_\t_\t_\t_\t_\t_',
710+
'3\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
711+
'4\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
712+
'5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_',
713+
'6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_',
714+
'7\tand\t_\t_\t_\t_\t_\t_\t_\t_',
715+
'7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
716+
'8\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
717+
'9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_',
718+
'10\tvery\t_\t_\t_\t_\t_\t_\t_\t_',
719+
'11\tkind\t_\t_\t_\t_\t_\t_\t_\t_',
720+
'12\t.\t_\t_\t_\t_\t_\t_\t_\t_']
721+
>>> sentence = CoNLLSentence(transform, lines) # fields in transform are built from ptb.
722+
>>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3]
723+
>>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp',
724+
'cc', 'det', 'dep', 'advmod', 'conj', 'punct']
725+
>>> sentence
726+
# text = But I found the location wonderful and the neighbors very kind.
727+
1 But _ _ _ _ 3 cc _ _
728+
2 I _ _ _ _ 3 nsubj _ _
729+
3 found _ _ _ _ 0 root _ _
730+
4 the _ _ _ _ 5 det _ _
731+
5 location _ _ _ _ 6 nsubj _ _
732+
6 wonderful _ _ _ _ 3 xcomp _ _
733+
7 and _ _ _ _ 6 cc _ _
734+
7.1 found _ _ _ _ _ _ _ _
735+
8 the _ _ _ _ 9 det _ _
736+
9 neighbors _ _ _ _ 11 dep _ _
737+
10 very _ _ _ _ 11 advmod _ _
738+
11 kind _ _ _ _ 6 conj _ _
739+
12 . _ _ _ _ 3 punct _ _
740+
"""
741+
742+
def __init__(self, transform, lines):
743+
super().__init__(transform)
744+
745+
self.values = []
746+
# record annotations for post-recovery
747+
self.annotations = dict()
748+
749+
for i, line in enumerate(lines):
750+
value = line.split('\t')
751+
if value[0].startswith('#') or not value[0].isdigit():
752+
self.annotations[-i-1] = line
753+
else:
754+
self.annotations[len(self.values)] = line
755+
self.values.append(value)
756+
self.values = list(zip(*self.values))
757+
758+
def __repr__(self):
759+
# cover the raw lines
760+
merged = {**self.annotations,
761+
**{i: '\t'.join(map(str, line))
762+
for i, line in enumerate(zip(*self.values))}}
763+
return '\n'.join(merged.values()) + '\n'
764+
765+
744766
class TreeSentence(Sentence):
745767
r"""
746768
Args:

0 commit comments

Comments
 (0)