@@ -9,25 +9,22 @@

class Dataset(torch.utils.data.Dataset):
    r"""
-    Dataset that is compatible with :class:`torch.utils.data.Dataset`.
-    This serves as a wrapper for manipulating all data fields
-    with the operating behaviours defined in :class:`Transform`.
+    Dataset that is compatible with :class:`torch.utils.data.Dataset`, serving as a wrapper for manipulating all data fields
+    with the operating behaviours defined in :class:`~supar.utils.transform.Transform`.
    The data fields of all the instantiated sentences can be accessed as an attribute of the dataset.

    Args:
        transform (Transform):
-            An instance of :class:`Transform` and its derivations.
-            The instance holds a series of loading and processing behaviours with regard to the specfic data format.
+            An instance of :class:`~supar.utils.transform.Transform` or its derivations.
+            The instance holds a series of loading and processing behaviours with regard to the specific data format.
        data (list[list] or str):
-            A list of instances or a filename.
-            This will be passed into :meth:`transform.load`.
+            A list of instances or a filename that will be passed into :meth:`transform.load`.
        kwargs (dict):
-            Keyword arguments that will be passed into :meth:`transform.load` together with `data`
-            to control the loading behaviour.
+            Together with `data`, kwargs will be passed into :meth:`transform.load` to control the loading behaviour.

    Attributes:
        transform (Transform):
-            An instance of :class:`Transform`.
+            An instance of :class:`~supar.utils.transform.Transform`.
        sentences (list[Sentence]):
            A list of sentences loaded from the data.
            Each sentence includes fields obeying the data format defined in ``transform``.
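
A minimal usage sketch of the wrapper documented above, not part of this commit: `transform` stands for a transform instance built elsewhere, 'train.conllx' is a hypothetical file in whatever format that transform expects, and the import path is assumed.

    from supar.utils import Dataset  # assumed import path for this sketch

    dataset = Dataset(transform, 'train.conllx')  # extra kwargs go to transform.load
    dataset.build(batch_size=32, n_buckets=8)     # numericalize fields and build buckets
    print(len(dataset), dataset.sentences[0])     # sentences loaded from the data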
@@ -54,10 +51,7 @@ def __len__(self):
        return len(self.sentences)

    def __getitem__(self, index):
-        if not hasattr(self, 'fields'):
-            raise RuntimeError("The fields are not numericalized. Please build the dataset first.")
-        for d in self.fields.values():
-            yield d[index]
+        return self.sentences[index]

    def __getattr__(self, name):
        if name in self.__dict__:
@@ -67,9 +61,7 @@ def __getattr__(self, name):
    def __setattr__(self, name, value):
        if 'sentences' in self.__dict__ and name in self.sentences[0]:
            # restore the order of sequences in the buckets
-            indices = torch.tensor([i
-                                    for bucket in self.buckets.values()
-                                    for i in bucket]).argsort()
+            indices = torch.tensor([i for bucket in self.buckets.values() for i in bucket]).argsort()
            for index, sentence in zip(indices, self.sentences):
                setattr(sentence, name, value[index])
        else:
@@ -83,19 +75,17 @@ def __setstate__(self, state):
        self.__dict__.update(state)

    def collate_fn(self, batch):
-        return {f: d for f, d in zip(self.fields.keys(), zip(*batch))}
+        if not hasattr(self, 'fields'):
+            raise RuntimeError("The fields are not numericalized yet. Please build the dataset first.")
+        return {f: [s.transformed[f.name] for s in batch] for f in self.fields}

    def build(self, batch_size, n_buckets=1, shuffle=False, distributed=False):
        # numericalize all fields
        self.fields = self.transform(self.sentences)
        # NOTE: the final bucket count is roughly equal to n_buckets
-        self.lengths = [len(i) for i in self.fields[next(iter(self.fields))]]
-        self.buckets = dict(zip(*kmeans(self.lengths, n_buckets)))
+        self.buckets = dict(zip(*kmeans([len(s.transformed[self.fields[0].name]) for s in self], n_buckets)))
        self.loader = DataLoader(dataset=self,
-                                 batch_sampler=Sampler(buckets=self.buckets,
-                                                       batch_size=batch_size,
-                                                       shuffle=shuffle,
-                                                       distributed=distributed),
+                                 batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed),
                                 collate_fn=self.collate_fn)

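The rewritten `build` above computes bucket assignments from the transformed length of each sentence via `kmeans`. A standalone sketch of the bucketing idea, not the library's `kmeans`, just an approximation of its effect of grouping similar lengths so batches pad minimally:

    def bucket_by_length(lengths, n_buckets):
        # sort sentence indices by length, then slice into contiguous groups,
        # so each bucket holds sentences of similar length
        order = sorted(range(len(lengths)), key=lambda i: lengths[i])
        size = -(-len(order) // n_buckets)  # ceiling division
        return {b: order[b * size:(b + 1) * size] for b in range(n_buckets)}

    print(bucket_by_length([3, 5, 4, 12, 11, 27], 2))  # {0: [0, 2, 1], 1: [4, 3, 5]}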
@@ -109,7 +99,7 @@ def __init__(self, *args, **kwargs):

    def __iter__(self):
        for batch in super().__iter__():
-            yield namedtuple('Batch', [f.name for f in batch.keys()])(*[f.compose(d) for f, d in batch.items()])
+            yield namedtuple('Batch', (f.name for f in batch.keys()))(*[f.compose(d) for f, d in batch.items()])


class Sampler(torch.utils.data.Sampler):
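
A small sketch of what the `Batch` namedtuple yielded above provides, with hypothetical field names and already-composed values: each field becomes an attribute, so downstream code can write `batch.words` instead of unpacking a dict.

    from collections import namedtuple

    batch_dict = {'words': [[2, 5, 7]], 'tags': [[1, 3, 3]]}  # hypothetical composed data
    Batch = namedtuple('Batch', batch_dict.keys())
    batch = Batch(*batch_dict.values())
    print(batch.words, batch.tags)  # [[2, 5, 7]] [[1, 3, 3]]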
@@ -148,15 +138,14 @@ def __iter__(self):
        g.manual_seed(self.epoch)
        range_fn = torch.arange
        # if `shuffle=True`, shuffle both the buckets and samples in each bucket
-        # for distributed training, make sure each process genertes the same random sequence at each epoch
+        # for distributed training, make sure each process generates the same random sequence at each epoch
        if self.shuffle:
            def range_fn(x):
                return torch.randperm(x, generator=g)
        total, count = 0, 0
        # TODO: more elegant way to deal with uneven data, which we directly discard right now
        for i in range_fn(len(self.buckets)).tolist():
-            split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1
-                           for j in range(self.chunks[i])]
+            split_sizes = [(len(self.buckets[i]) - j - 1) // self.chunks[i] + 1 for j in range(self.chunks[i])]
            # DON'T use `torch.chunk` which may return wrong number of chunks
            for batch in range_fn(len(self.buckets[i])).split(split_sizes):
                if count == self.samples:
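
A worked check of the `split_sizes` arithmetic above: a bucket of n samples is split into `chunks` parts whose sizes differ by at most one and always sum to n, which `torch.chunk` does not guarantee for the number of chunks.

    n, chunks = 10, 3
    split_sizes = [(n - j - 1) // chunks + 1 for j in range(chunks)]
    print(split_sizes, sum(split_sizes))  # [4, 3, 3] 10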