@@ -83,57 +83,6 @@ def save(self, path, sentences):
83
83
f .write ('\n ' .join ([str (i ) for i in sentences ]) + '\n ' )
84
84
85
85
86
- class Sentence (object ):
87
- r"""
88
- A Sentence object holds a sentence with regard to specific data format.
89
- """
90
-
91
- def __init__ (self , transform ):
92
- self .transform = transform
93
-
94
- # mapping from each nested field to their proper position
95
- self .maps = dict ()
96
- # names of each field
97
- self .keys = set ()
98
- for i , field in enumerate (self .transform ):
99
- if not isinstance (field , Iterable ):
100
- field = [field ]
101
- for f in field :
102
- if f is not None :
103
- self .maps [f .name ] = i
104
- self .keys .add (f .name )
105
- # original values and numericalized values of each position
106
- self .values = []
107
- self .transformed = {key : None for key in self .keys }
108
-
109
- def __contains__ (self , key ):
110
- return key in self .keys
111
-
112
- def __getattr__ (self , name ):
113
- if name in self .__dict__ :
114
- return self .__dict__ [name ]
115
- elif name in self .maps :
116
- return self .values [self .maps [name ]]
117
- else :
118
- raise AttributeError
119
-
120
- def __setattr__ (self , name , value ):
121
- if 'keys' in self .__dict__ and name in self :
122
- index = self .maps [name ]
123
- if index >= len (self .values ):
124
- self .__dict__ [name ] = value
125
- else :
126
- self .values [index ] = value
127
- else :
128
- self .__dict__ [name ] = value
129
-
130
- def __getstate__ (self ):
131
- return vars (self )
132
-
133
- def __setstate__ (self , state ):
134
- self .__dict__ .update (state )
135
-
136
-
137
86
class CoNLL (Transform ):
138
87
r"""
139
88
The CoNLL object holds ten fields required for CoNLL-X data format :cite:`buchholz-marsi-2006-conll`.
@@ -402,77 +351,6 @@ def load(self, data, lang=None, proj=False, max_len=None, **kwargs):
402
351
return sentences
403
352
404
353
405
- class CoNLLSentence (Sentence ):
406
- r"""
407
- Sencence in CoNLL-X format.
408
-
409
- Args:
410
- transform (CoNLL):
411
- A :class:`~supar.utils.transform.CoNLL` object.
412
- lines (list[str]):
413
- A list of strings composing a sentence in CoNLL-X format.
414
- Comments and non-integer IDs are permitted.
415
-
416
- Examples:
417
- >>> lines = ['# text = But I found the location wonderful and the neighbors very kind.',
418
- '1\tBut\t_\t_\t_\t_\t_\t_\t_\t_',
419
- '2\tI\t_\t_\t_\t_\t_\t_\t_\t_',
420
- '3\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
421
- '4\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
422
- '5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_',
423
- '6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_',
424
- '7\tand\t_\t_\t_\t_\t_\t_\t_\t_',
425
- '7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
426
- '8\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
427
- '9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_',
428
- '10\tvery\t_\t_\t_\t_\t_\t_\t_\t_',
429
- '11\tkind\t_\t_\t_\t_\t_\t_\t_\t_',
430
- '12\t.\t_\t_\t_\t_\t_\t_\t_\t_']
431
- >>> sentence = CoNLLSentence(transform, lines) # fields in transform are built from ptb.
432
- >>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3]
433
- >>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp',
434
- 'cc', 'det', 'dep', 'advmod', 'conj', 'punct']
435
- >>> sentence
436
- # text = But I found the location wonderful and the neighbors very kind.
437
- 1 But _ _ _ _ 3 cc _ _
438
- 2 I _ _ _ _ 3 nsubj _ _
439
- 3 found _ _ _ _ 0 root _ _
440
- 4 the _ _ _ _ 5 det _ _
441
- 5 location _ _ _ _ 6 nsubj _ _
442
- 6 wonderful _ _ _ _ 3 xcomp _ _
443
- 7 and _ _ _ _ 6 cc _ _
444
- 7.1 found _ _ _ _ _ _ _ _
445
- 8 the _ _ _ _ 9 det _ _
446
- 9 neighbors _ _ _ _ 11 dep _ _
447
- 10 very _ _ _ _ 11 advmod _ _
448
- 11 kind _ _ _ _ 6 conj _ _
449
- 12 . _ _ _ _ 3 punct _ _
450
- """
451
-
452
- def __init__ (self , transform , lines ):
453
- super ().__init__ (transform )
454
-
455
- self .values = []
456
- # record annotations for post-recovery
457
- self .annotations = dict ()
458
-
459
- for i , line in enumerate (lines ):
460
- value = line .split ('\t ' )
461
- if value [0 ].startswith ('#' ) or not value [0 ].isdigit ():
462
- self .annotations [- i - 1 ] = line
463
- else :
464
- self .annotations [len (self .values )] = line
465
- self .values .append (value )
466
- self .values = list (zip (* self .values ))
467
-
468
- def __repr__ (self ):
469
- # cover the raw lines
470
- merged = {** self .annotations ,
471
- ** {i : '\t ' .join (map (str , line ))
472
- for i , line in enumerate (zip (* self .values ))}}
473
- return '\n ' .join (merged .values ()) + '\n '
474
-
475
-
476
354
class Tree (Transform ):
477
355
r"""
478
356
The Tree object factorize a constituency tree into four fields,
@@ -741,6 +619,150 @@ def load(self, data, lang=None, max_len=None, **kwargs):
741
619
return sentences
742
620
743
621
622
+ class Batch (object ):
623
+
624
+ def __init__ (self , sentences ):
625
+ self .sentences = sentences
626
+ self .transformed = {f .name : f .compose ([s .transformed [f .name ] for s in sentences ])
627
+ for f in sentences [0 ].transform .flattened_fields }
628
+ self .fields = list (self .transformed .keys ())
629
+
630
+ def __repr__ (self ):
631
+ s = ', ' .join ([f"{ name } " for name in self .fields ])
632
+ return f"{ self .__class__ .__name__ } ({ s } )"
633
+
634
+ def __getitem__ (self , index ):
635
+ return self .transformed [self .fields [index ]]
636
+
637
+ def __getattr__ (self , name ):
638
+ if name in self .__dict__ :
639
+ return self .__dict__ [name ]
640
+ if name in self .transformed :
641
+ return self .transformed [name ]
642
+ if hasattr (self .sentences [0 ], name ):
643
+ return [getattr (s , name ) for s in self .sentences ]
644
+ raise AttributeError
645
+
646
+
647
+ class Sentence (object ):
648
+
649
+ def __init__ (self , transform ):
650
+ self .transform = transform
651
+
652
+ # mapping from each nested field to their proper position
653
+ self .maps = dict ()
654
+ # names of each field
655
+ self .keys = set ()
656
+ for i , field in enumerate (self .transform ):
657
+ if not isinstance (field , Iterable ):
658
+ field = [field ]
659
+ for f in field :
660
+ if f is not None :
661
+ self .maps [f .name ] = i
662
+ self .keys .add (f .name )
663
+ # original values and numericalized values of each position
664
+ self .values = []
665
+ self .transformed = {key : None for key in self .keys }
666
+
667
+ def __contains__ (self , key ):
668
+ return key in self .keys
669
+
670
+ def __getattr__ (self , name ):
671
+ if name in self .__dict__ :
672
+ return self .__dict__ [name ]
673
+ elif name in self .maps :
674
+ return self .values [self .maps [name ]]
675
+ else :
676
+ raise AttributeError
677
+
678
+ def __setattr__ (self , name , value ):
679
+ if 'keys' in self .__dict__ and name in self :
680
+ index = self .maps [name ]
681
+ if index >= len (self .values ):
682
+ self .__dict__ [name ] = value
683
+ else :
684
+ self .values [index ] = value
685
+ else :
686
+ self .__dict__ [name ] = value
687
+
688
+ def __getstate__ (self ):
689
+ return vars (self )
690
+
691
+ def __setstate__ (self , state ):
692
+ self .__dict__ .update (state )
693
+
694
+
695
+ class CoNLLSentence (Sentence ):
696
+ r"""
697
+ Sencence in CoNLL-X format.
698
+
699
+ Args:
700
+ transform (CoNLL):
701
+ A :class:`~supar.utils.transform.CoNLL` object.
702
+ lines (list[str]):
703
+ A list of strings composing a sentence in CoNLL-X format.
704
+ Comments and non-integer IDs are permitted.
705
+
706
+ Examples:
707
+ >>> lines = ['# text = But I found the location wonderful and the neighbors very kind.',
708
+ '1\tBut\t_\t_\t_\t_\t_\t_\t_\t_',
709
+ '2\tI\t_\t_\t_\t_\t_\t_\t_\t_',
710
+ '3\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
711
+ '4\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
712
+ '5\tlocation\t_\t_\t_\t_\t_\t_\t_\t_',
713
+ '6\twonderful\t_\t_\t_\t_\t_\t_\t_\t_',
714
+ '7\tand\t_\t_\t_\t_\t_\t_\t_\t_',
715
+ '7.1\tfound\t_\t_\t_\t_\t_\t_\t_\t_',
716
+ '8\tthe\t_\t_\t_\t_\t_\t_\t_\t_',
717
+ '9\tneighbors\t_\t_\t_\t_\t_\t_\t_\t_',
718
+ '10\tvery\t_\t_\t_\t_\t_\t_\t_\t_',
719
+ '11\tkind\t_\t_\t_\t_\t_\t_\t_\t_',
720
+ '12\t.\t_\t_\t_\t_\t_\t_\t_\t_']
721
+ >>> sentence = CoNLLSentence(transform, lines) # fields in transform are built from ptb.
722
+ >>> sentence.arcs = [3, 3, 0, 5, 6, 3, 6, 9, 11, 11, 6, 3]
723
+ >>> sentence.rels = ['cc', 'nsubj', 'root', 'det', 'nsubj', 'xcomp',
724
+ 'cc', 'det', 'dep', 'advmod', 'conj', 'punct']
725
+ >>> sentence
726
+ # text = But I found the location wonderful and the neighbors very kind.
727
+ 1 But _ _ _ _ 3 cc _ _
728
+ 2 I _ _ _ _ 3 nsubj _ _
729
+ 3 found _ _ _ _ 0 root _ _
730
+ 4 the _ _ _ _ 5 det _ _
731
+ 5 location _ _ _ _ 6 nsubj _ _
732
+ 6 wonderful _ _ _ _ 3 xcomp _ _
733
+ 7 and _ _ _ _ 6 cc _ _
734
+ 7.1 found _ _ _ _ _ _ _ _
735
+ 8 the _ _ _ _ 9 det _ _
736
+ 9 neighbors _ _ _ _ 11 dep _ _
737
+ 10 very _ _ _ _ 11 advmod _ _
738
+ 11 kind _ _ _ _ 6 conj _ _
739
+ 12 . _ _ _ _ 3 punct _ _
740
+ """
741
+
742
+ def __init__ (self , transform , lines ):
743
+ super ().__init__ (transform )
744
+
745
+ self .values = []
746
+ # record annotations for post-recovery
747
+ self .annotations = dict ()
748
+
749
+ for i , line in enumerate (lines ):
750
+ value = line .split ('\t ' )
751
+ if value [0 ].startswith ('#' ) or not value [0 ].isdigit ():
752
+ self .annotations [- i - 1 ] = line
753
+ else :
754
+ self .annotations [len (self .values )] = line
755
+ self .values .append (value )
756
+ self .values = list (zip (* self .values ))
757
+
758
+ def __repr__ (self ):
759
+ # cover the raw lines
760
+ merged = {** self .annotations ,
761
+ ** {i : '\t ' .join (map (str , line ))
762
+ for i , line in enumerate (zip (* self .values ))}}
763
+ return '\n ' .join (merged .values ()) + '\n '
764
+
765
+
744
766
class TreeSentence (Sentence ):
745
767
r"""
746
768
Args:
0 commit comments