Skip to content

Commit 575e278

Browse files
committed
Avoid duplicate field storage
1 parent d4168ff commit 575e278

File tree

2 files changed

+17
-21
lines changed

2 files changed

+17
-21
lines changed

supar/utils/data.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class Dataset(torch.utils.data.Dataset):
3434
Each sentence includes fields obeying the data format defined in ``transform``.
3535
"""
3636

37-
def __init__(self, transform: Transform, data: Union[List[List], str], **kwargs) -> Dataset:
37+
def __init__(self, transform: Transform, data: Union[str, List[List]], **kwargs) -> Dataset:
3838
super(Dataset, self).__init__()
3939

4040
self.transform = transform
@@ -82,10 +82,10 @@ def build(self, batch_size: int, n_buckets: int = 1, shuffle: bool = False, dist
8282
# numericalize all fields
8383
fields = self.transform(self.sentences)
8484
# NOTE: the final bucket count is roughly equal to n_buckets
85-
self.buckets = dict(zip(*kmeans([len(s.transformed[fields[0].name]) for s in self], n_buckets)))
85+
self.buckets = dict(zip(*kmeans([len(s.fields[fields[0].name]) for s in self], n_buckets)))
8686
self.loader = DataLoader(dataset=self,
8787
batch_sampler=Sampler(self.buckets, batch_size, shuffle, distributed),
88-
collate_fn=lambda x: Batch(x))
88+
collate_fn=lambda x: Batch(self.transform, x))
8989
return self
9090

9191

supar/utils/transform.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __call__(self, sentences):
4444
# numericalize the fields of each sentence
4545
for sentence in progress_bar(sentences):
4646
for f in self.flattened_fields:
47-
sentence.transformed[f.name] = f.transform([getattr(sentence, f.name)])[0]
47+
sentence.fields[f.name] = f.transform([getattr(sentence, f.name)])[0]
4848
return self.flattened_fields
4949

5050
def __getitem__(self, index):
@@ -322,12 +322,12 @@ def istree(cls, sequence: List[int], proj: bool = False, multiroot: bool = False
322322

323323
def load(
324324
self,
325-
data: Union[List[List], str],
325+
data: Union[str, List[List]],
326326
lang: Optional[str] = None,
327327
proj: bool = False,
328328
max_len: Optional[int] = None,
329329
**kwargs
330-
) -> List['CoNLLSentence']:
330+
) -> List[CoNLLSentence]:
331331
r"""
332332
Loads the data in CoNLL-X format.
333333
Also supports for loading data from CoNLL-U file with comments and non-integer IDs.
@@ -622,11 +622,11 @@ def track(node):
622622

623623
def load(
624624
self,
625-
data: Union[List[List], str],
625+
data: Union[str, List[List]],
626626
lang: Optional[str] = None,
627627
max_len: Optional[int] = None,
628628
**kwargs
629-
) -> List['TreeSentence']:
629+
) -> List[TreeSentence]:
630630
r"""
631631
Args:
632632
data (list[list] or str):
@@ -665,24 +665,22 @@ def load(
665665

666666
class Batch(object):
667667

668-
def __init__(self, sentences):
668+
def __init__(self, transform, sentences):
669669
self.sentences = sentences
670-
self.transformed = {f.name: f.compose([s.transformed[f.name] for s in sentences])
671-
for f in sentences[0].transform.flattened_fields}
672-
self.fields = list(self.transformed.keys())
670+
self.fields = {f.name: f.compose([s.fields[f.name] for s in sentences]) for f in transform.flattened_fields}
671+
self.names = list(self.fields.keys())
673672

674673
def __repr__(self):
675-
s = ', '.join([f"{name}" for name in self.fields])
676-
return f"{self.__class__.__name__}({s})"
674+
return f'{self.__class__.__name__}({", ".join([f"{name}" for name in self.names])})'
677675

678676
def __getitem__(self, index):
679-
return self.transformed[self.fields[index]]
677+
return self.fields[self.names[index]]
680678

681679
def __getattr__(self, name):
682680
if name in self.__dict__:
683681
return self.__dict__[name]
684-
if name in self.transformed:
685-
return self.transformed[name]
682+
if name in self.fields:
683+
return self.fields[name]
686684
if hasattr(self.sentences[0], name):
687685
return [getattr(s, name) for s in self.sentences]
688686
raise AttributeError
@@ -691,13 +689,11 @@ def __getattr__(self, name):
691689
class Sentence(object):
692690

693691
def __init__(self, transform):
694-
self.transform = transform
695-
696692
# mapping from each nested field to their proper position
697693
self.maps = dict()
698694
# names of each field
699695
self.keys = set()
700-
for i, field in enumerate(self.transform):
696+
for i, field in enumerate(transform):
701697
if not isinstance(field, Iterable):
702698
field = [field]
703699
for f in field:
@@ -706,7 +702,7 @@ def __init__(self, transform):
706702
self.keys.add(f.name)
707703
# original values and numericalized values of each position
708704
self.values = []
709-
self.transformed = {key: None for key in self.keys}
705+
self.fields = {key: None for key in self.keys}
710706

711707
def __contains__(self, key):
712708
return key in self.keys

0 commit comments

Comments
 (0)