@@ -44,7 +44,7 @@ def __call__(self, sentences):
44
44
# numericalize the fields of each sentence
45
45
for sentence in progress_bar (sentences ):
46
46
for f in self .flattened_fields :
47
- sentence .transformed [f .name ] = f .transform ([getattr (sentence , f .name )])[0 ]
47
+ sentence .fields [f .name ] = f .transform ([getattr (sentence , f .name )])[0 ]
48
48
return self .flattened_fields
49
49
50
50
def __getitem__ (self , index ):
@@ -322,12 +322,12 @@ def istree(cls, sequence: List[int], proj: bool = False, multiroot: bool = False
322
322
323
323
def load (
324
324
self ,
325
- data : Union [List [List ], str ],
325
+ data : Union [str , List [List ]],
326
326
lang : Optional [str ] = None ,
327
327
proj : bool = False ,
328
328
max_len : Optional [int ] = None ,
329
329
** kwargs
330
- ) -> List [' CoNLLSentence' ]:
330
+ ) -> List [CoNLLSentence ]:
331
331
r"""
332
332
Loads the data in CoNLL-X format.
333
333
Also supports for loading data from CoNLL-U file with comments and non-integer IDs.
@@ -622,11 +622,11 @@ def track(node):
622
622
623
623
def load (
624
624
self ,
625
- data : Union [List [List ], str ],
625
+ data : Union [str , List [List ]],
626
626
lang : Optional [str ] = None ,
627
627
max_len : Optional [int ] = None ,
628
628
** kwargs
629
- ) -> List [' TreeSentence' ]:
629
+ ) -> List [TreeSentence ]:
630
630
r"""
631
631
Args:
632
632
data (list[list] or str):
@@ -665,24 +665,22 @@ def load(
665
665
666
666
class Batch (object ):
667
667
668
- def __init__ (self , sentences ):
668
+ def __init__ (self , transform , sentences ):
669
669
self .sentences = sentences
670
- self .transformed = {f .name : f .compose ([s .transformed [f .name ] for s in sentences ])
671
- for f in sentences [0 ].transform .flattened_fields }
672
- self .fields = list (self .transformed .keys ())
670
+ self .fields = {f .name : f .compose ([s .fields [f .name ] for s in sentences ]) for f in transform .flattened_fields }
671
+ self .names = list (self .fields .keys ())
673
672
674
673
def __repr__ (self ):
675
- s = ', ' .join ([f"{ name } " for name in self .fields ])
676
- return f"{ self .__class__ .__name__ } ({ s } )"
674
+ return f'{ self .__class__ .__name__ } ({ ", " .join ([f"{ name } " for name in self .names ])} )'
677
675
678
676
def __getitem__ (self , index ):
679
- return self .transformed [self .fields [index ]]
677
+ return self .fields [self .names [index ]]
680
678
681
679
def __getattr__ (self , name ):
682
680
if name in self .__dict__ :
683
681
return self .__dict__ [name ]
684
- if name in self .transformed :
685
- return self .transformed [name ]
682
+ if name in self .fields :
683
+ return self .fields [name ]
686
684
if hasattr (self .sentences [0 ], name ):
687
685
return [getattr (s , name ) for s in self .sentences ]
688
686
raise AttributeError
@@ -691,13 +689,11 @@ def __getattr__(self, name):
691
689
class Sentence (object ):
692
690
693
691
def __init__ (self , transform ):
694
- self .transform = transform
695
-
696
692
# mapping from each nested field to their proper position
697
693
self .maps = dict ()
698
694
# names of each field
699
695
self .keys = set ()
700
- for i , field in enumerate (self . transform ):
696
+ for i , field in enumerate (transform ):
701
697
if not isinstance (field , Iterable ):
702
698
field = [field ]
703
699
for f in field :
@@ -706,7 +702,7 @@ def __init__(self, transform):
706
702
self .keys .add (f .name )
707
703
# original values and numericalized values of each position
708
704
self .values = []
709
- self .transformed = {key : None for key in self .keys }
705
+ self .fields = {key : None for key in self .keys }
710
706
711
707
def __contains__ (self , key ):
712
708
return key in self .keys
0 commit comments