@@ -42,8 +42,9 @@ def get_log_p_z(self, emissions, mask, seq_len):
42
42
log_alpha += self .transitions [: self .start_tag , self .end_tag ].unsqueeze (0 )
43
43
return torch .logsumexp (log_alpha .squeeze (1 ), 1 )
44
44
45
- def get_log_p_Y_X (self , emissions , mask , seq_len , tags ):
46
- tags [tags < 0 ] = 0 # clone and then set
45
+ def get_log_p_Y_X (self , emissions , mask , seq_len , orig_tags ):
46
+ tags = orig_tags .clone ()
47
+ tags [tags < 0 ] = 0
47
48
48
49
llh = self .transitions [self .start_tag , tags [:, 0 ]].unsqueeze (1 )
49
50
llh += emissions [:, 0 , :].gather (1 , tags [:, 0 ].view (- 1 , 1 )) * mask [:, 0 ].unsqueeze (1 )
@@ -106,28 +107,28 @@ def viterbi_decode(self, emissions, lengths):
106
107
_ , max_indices_from_scores = torch .max (best_scores , 2 )
107
108
108
109
valid_index_tensor = torch .tensor (0 ).long ()
109
- padding_tensor = torch .tensor (Constants .PAD_ID ).long ()
110
+ padding_tensor = torch .tensor (Constants .TAG_PAD_ID ).long ()
110
111
111
112
labels = max_indices_from_scores [:, seq_len - 1 ]
112
- labels = torch .where (1.0 - mask [:, seq_len - 1 ], padding_tensor , labels )
113
+ labels = torch .where (mask [:, seq_len - 1 ] != 1.0 , padding_tensor , labels )
113
114
all_labels = labels .unsqueeze (1 ).long ()
114
115
115
116
for idx in range (seq_len - 2 , - 1 , - 1 ):
116
117
indices_for_lookup = all_labels [:, - 1 ].clone ()
117
- indices_for_lookup = torch .where (indices_for_lookup == self . ignore_index , valid_index_tensor ,
118
+ indices_for_lookup = torch .where (indices_for_lookup == Constants . TAG_PAD_ID , valid_index_tensor ,
118
119
indices_for_lookup )
119
120
120
121
indices_from_prev_pos = best_paths [:, idx , :].gather (1 , indices_for_lookup .view (- 1 , 1 ).long ()).squeeze (1 )
121
- indices_from_prev_pos = torch .where (( 1.0 - mask [:, idx + 1 ]) , padding_tensor , indices_from_prev_pos )
122
+ indices_from_prev_pos = torch .where (mask [:, idx + 1 ] != 1.0 , padding_tensor , indices_from_prev_pos )
122
123
123
124
indices_from_max_scores = max_indices_from_scores [:, idx ]
124
- indices_from_max_scores = torch .where (mask [:, idx + 1 ], padding_tensor , indices_from_max_scores )
125
+ indices_from_max_scores = torch .where (mask [:, idx + 1 ] == 1.0 , padding_tensor , indices_from_max_scores )
125
126
126
- labels = torch .where (indices_from_max_scores == self . ignore_index , indices_from_prev_pos ,
127
+ labels = torch .where (indices_from_max_scores == Constants . TAG_PAD_ID , indices_from_prev_pos ,
127
128
indices_from_max_scores )
128
129
129
130
# Set to ignore_index if present state is not valid.
130
- labels = torch .where (( 1 - mask [:, idx ]), padding_tensor , labels )
131
+ labels = torch .where (mask [:, idx ] != 1.0 , padding_tensor , labels )
131
132
all_labels = torch .cat ((all_labels , labels .view (- 1 , 1 ).long ()), 1 )
132
133
133
134
return best_scores , torch .flip (all_labels , [1 ])
0 commit comments