 import torch
 import torch.nn as nn
 from data_utils.sentence_utils import Constants
-
+import numpy as np

 def get_mask(lengths):
     seq_lens = lengths.view(-1, 1)
@@ -16,40 +16,35 @@ def get_mask(lengths):
 class CRF_Loss(nn.Module):
     def __init__(self, tagset_size):
         super(CRF_Loss, self).__init__()
-        self.start_tag_idx = tagset_size
-        self.stop_tag_idx = tagset_size + 1
+        self.start_tag = tagset_size
+        self.end_tag = tagset_size + 1
         self.num_tags = tagset_size + 2

-        # transition from y_i-1 to y_i, T[y_i, y_j] = y_i <= y_j
-        # +2 added for start and end indices
         self.transitions = nn.Parameter(torch.Tensor(self.num_tags, self.num_tags))
-        nn.init.uniform_(self.transitions, -0.1, 0.1)
+        # initialize all transitions to log(1 / num_tags), i.e. uniform in log space
+        nn.init.constant_(self.transitions, -np.log(self.num_tags))

-        # no transition to start_tag, no transition from end tag
-        self.transitions.data[self.start_tag_idx, :] = -10000
-        self.transitions.data[:, self.stop_tag_idx] = -10000
+        # forbid transitions out of the end tag and into the start tag
+        self.transitions.data[self.end_tag, :] = -10000
+        self.transitions.data[:, self.start_tag] = -10000

     def get_log_p_z(self, emissions, mask, seq_len):
         log_alpha = emissions[:, 0].clone()
-        log_alpha += self.transitions[self.start_tag_idx, : self.start_tag_idx].unsqueeze(0)
+        log_alpha += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)

         for idx in range(1, seq_len):
             broadcast_emissions = emissions[:, idx].unsqueeze(1)
-            broadcast_transitions = self.transitions[
-                : self.start_tag, : self.start_tag
-            ].unsqueeze(0)
+            broadcast_transitions = self.transitions[: self.start_tag, : self.start_tag].unsqueeze(0)
             broadcast_logprob = log_alpha.unsqueeze(2)
             score = broadcast_logprob + broadcast_emissions + broadcast_transitions

             score = torch.logsumexp(score, 1)
-            log_alpha = score * mask[:, idx].unsqueeze(1) + log_alpha.squeeze(1) * (
-                1.0 - mask[:, idx].unsqueeze(1)
-            )
+            log_alpha = score * mask[:, idx].unsqueeze(1) + log_alpha.squeeze(1) * (1.0 - mask[:, idx].unsqueeze(1))

         log_alpha += self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)
         return torch.logsumexp(log_alpha.squeeze(1), 1)

     def get_log_p_Y_X(self, emissions, mask, seq_len, tags):
+        tags = tags.clone()  # clone first so the caller's tags are not mutated in place
+        tags[tags < 0] = 0   # padded positions get a valid index; the mask zeroes them out
+
         llh = self.transitions[self.start_tag, tags[:, 0]].unsqueeze(1)
         llh += emissions[:, 0, :].gather(1, tags[:, 0].view(-1, 1)) * mask[:, 0].unsqueeze(1)
@@ -79,20 +74,14 @@ def log_likelihood(self, emissions, tags):
     def forward(self, emissions, tags):
         return self.log_likelihood(emissions, tags)

-    def inference(self, emissions, lengths):
-        return self.viterbi_decode(emissions, lengths)
-
     def viterbi_decode(self, emissions, lengths):
         mask = get_mask(lengths)
         seq_len = emissions.shape[1]

         log_prob = emissions[:, 0].clone()
         log_prob += self.transitions[self.start_tag, : self.start_tag].unsqueeze(0)

-
-        end_scores = log_prob + self.transitions[
-            : self.start_tag, self.end_tag
-        ].unsqueeze(0)
+        end_scores = log_prob + self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)

         best_scores_list = []
         best_scores_list.append(end_scores.unsqueeze(1))
@@ -101,20 +90,12 @@ def viterbi_decode(self, emissions, lengths):
         for idx in range(1, seq_len):
             broadcast_emissions = emissions[:, idx].unsqueeze(1)
-            broadcast_transmissions = self.transitions[
-                : self.start_tag, : self.start_tag
-            ].unsqueeze(0)
+            broadcast_transmissions = self.transitions[: self.start_tag, : self.start_tag].unsqueeze(0)
             broadcast_log_prob = log_prob.unsqueeze(2)
-
             score = broadcast_emissions + broadcast_transmissions + broadcast_log_prob
-
             max_scores, max_score_indices = torch.max(score, 1)
-
             best_paths_list.append(max_score_indices.unsqueeze(1))
-
-            end_scores = max_scores + self.transitions[
-                : self.start_tag, self.end_tag
-            ].unsqueeze(0)
+            end_scores = max_scores + self.transitions[: self.start_tag, self.end_tag].unsqueeze(0)

             best_scores_list.append(end_scores.unsqueeze(1))
             log_prob = max_scores
@@ -128,48 +109,25 @@ def viterbi_decode(self, emissions, lengths):
         padding_tensor = torch.tensor(Constants.PAD_ID).long()

         labels = max_indices_from_scores[:, seq_len - 1]
-        labels = self._mask_tensor(labels, 1.0 - mask[:, seq_len - 1], padding_tensor)
-
+        # torch.where needs a bool condition, so the 0/1 mask is compared explicitly
+        labels = torch.where(mask[:, seq_len - 1] == 0, padding_tensor, labels)
         all_labels = labels.unsqueeze(1).long()

         for idx in range(seq_len - 2, -1, -1):
             indices_for_lookup = all_labels[:, -1].clone()
-            indices_for_lookup = torch.where(
-                indices_for_lookup == self.ignore_index,
-                valid_index_tensor,
-                indices_for_lookup
-            )
+            indices_for_lookup = torch.where(indices_for_lookup == self.ignore_index,
+                                             valid_index_tensor, indices_for_lookup)

-            indices_from_prev_pos = (
-                best_paths[:, idx, :]
-                .gather(1, indices_for_lookup.view(-1, 1).long())
-                .squeeze(1)
-            )
-            indices_from_prev_pos = torch.where(
-                (1.0 - mask[:, idx + 1]),
-                padding_tensor,
-                indices_from_prev_pos
-            )
+            indices_from_prev_pos = best_paths[:, idx, :].gather(1, indices_for_lookup.view(-1, 1).long()).squeeze(1)
+            indices_from_prev_pos = torch.where(mask[:, idx + 1] == 0, padding_tensor, indices_from_prev_pos)

             indices_from_max_scores = max_indices_from_scores[:, idx]
-            indices_from_max_scores = torch.where(
-                mask[:, idx + 1],
-                padding_tensor,
-                indices_from_max_scores
-            )
+            # if the next step is valid, force the backpointer result to win below
+            indices_from_max_scores = torch.where(mask[:, idx + 1] == 1, padding_tensor, indices_from_max_scores)

-            labels = torch.where(
-                indices_from_max_scores == self.ignore_index,
-                indices_from_prev_pos,
-                indices_from_max_scores,
-            )
+            labels = torch.where(indices_from_max_scores == self.ignore_index,
+                                 indices_from_prev_pos, indices_from_max_scores)

             # Set to ignore_index if present state is not valid.
-            labels = torch.where(
-                (1 - mask[:, idx]),
-                padding_tensor,
-                labels
-            )
+            labels = torch.where(mask[:, idx] == 0, padding_tensor, labels)

             all_labels = torch.cat((all_labels, labels.view(-1, 1).long()), 1)

         return best_scores, torch.flip(all_labels, [1])
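
For orientation, here is a minimal usage sketch of CRF_Loss after this change. The shapes and values are assumptions read off the diff, not the repository's own tests: emissions are taken to score only the real tags, (batch, seq_len, tagset_size), since the start/end columns exist only in the transition matrix; padded tag positions use -1; and lengths is a per-sequence length tensor as get_mask expects. The elided body of log_likelihood is assumed to take exactly (emissions, tags).

    import torch

    tagset_size = 5
    crf = CRF_Loss(tagset_size)

    batch, seq_len = 2, 7
    emissions = torch.randn(batch, seq_len, tagset_size)    # scores over real tags only
    tags = torch.randint(0, tagset_size, (batch, seq_len))
    tags[1, 4:] = -1                                        # second sequence has length 4

    loss = crf(emissions, tags)                             # forward() delegates to log_likelihood()

    lengths = torch.tensor([7, 4])
    best_scores, best_paths = crf.viterbi_decode(emissions, lengths)  # best_paths: (batch, seq_len)

On the initialization change itself: -np.log(self.num_tags) gives every transition the log-probability of a uniform next-tag distribution, which tends to be a gentler starting point than small random values from nn.init.uniform_.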