@@ -56,8 +56,6 @@ def rouge_fn(predictions, labels):
56
56
metric names to numerical values.
57
57
eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
58
58
Validation data to be used to generate predictions for the `metric_fn`.
59
- metric_fn_kwargs (`dict`, *optional*):
60
- Additional keyword arguments to be passed to the metric_fn.
61
59
output_cols (`List[str]`, *optional*):
62
60
A list of columns to be retained from the model output as the predictions. Defaults to all.
63
61
label_cols (`List[str]`, *optional*):
@@ -74,7 +72,6 @@ def __init__(
74
72
self ,
75
73
metric_fn : Callable ,
76
74
eval_dataset : Union [tf .data .Dataset , np .ndarray , tf .Tensor , tuple , dict ],
77
- metric_fn_kwargs : Optional [dict ] = None ,
78
75
output_cols : Optional [List [str ]] = None ,
79
76
label_cols : Optional [List [str ]] = None ,
80
77
batch_size : Optional [int ] = None ,
@@ -94,12 +91,6 @@ def __init__(
94
91
self .eval_dataset = eval_dataset
95
92
self .predict_with_generate = predict_with_generate
96
93
self .output_cols = output_cols
97
- self .metric_fn_kwargs = metric_fn_kwargs or dict ()
98
-
99
- if hasattr (self .model , "encoder" ) and self .model .encoder .main_input_name != self .model .main_input_name :
100
- self .main_input_name = self .model .encoder .main_input_name
101
- else :
102
- self .main_input_name = self .model .main_input_name
103
94
104
95
# This next block attempts to parse out which elements of the dataset should be appended to the labels list
105
96
# that is passed to the metric_fn
@@ -123,32 +114,75 @@ def __init__(
123
114
self .label_cols = ["labels" ]
124
115
self .use_keras_label = False
125
116
logging .warning ("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key." )
117
+ elif "start_positions" in input_spec and "end_positions" in input_spec :
118
+ self .label_cols = ["start_positions" , "end_positions" ]
119
+ self .use_keras_label = False
120
+ logging .warning (
121
+ "No label_cols specified for KerasMetricCallback, assuming you want the "
122
+ "start_positions and end_positions keys."
123
+ )
126
124
else :
127
125
raise ValueError ("Could not autodetect label_cols for KerasMetricCallback, please specify them!" )
128
- if parse (tf .__version__ ). minor < parse ("2.7" ):
126
+ if parse (tf .__version__ ) < parse ("2.7" ):
129
127
logging .warning ("TF versions less than 2.7 may encounter issues with KerasMetricCallback!" )
130
128
131
129
@staticmethod
132
- def _concatenate_batches (batches ):
133
- # Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
134
- return [sample for batch in batches for sample in batch ]
130
+ def _concatenate_batches (batches , padding_index = - 100 ):
131
+ # If all batches are unidimensional or same length, do a simple concatenation
132
+ if batches [0 ].ndim == 1 or all ([batch .shape [1 ] == batches [0 ].shape [1 ] for batch in batches ]):
133
+ return np .concatenate (batches , axis = 0 )
134
+
135
+ # Welp, they're not the same length. Let's do some padding
136
+ max_len = max ([batch .shape [1 ] for batch in batches ])
137
+ num_samples = sum ([batch .shape [0 ] for batch in batches ])
138
+ output = np .full_like (
139
+ batches [0 ], fill_value = padding_index , shape = [num_samples , max_len ] + list (batches [0 ].shape [2 :])
140
+ )
141
+ # i keeps track of which part of the concatenated array we're writing the next batch to
142
+ i = 0
143
+ for batch in batches :
144
+ output [i : i + len (batch ), : batch .shape [1 ]] = batch
145
+ i += len (batch )
146
+ return output
135
147
136
148
def _postprocess_predictions_or_labels (self , inputs ):
137
149
if isinstance (inputs [0 ], dict ):
138
150
outputs = dict ()
139
151
for key in inputs [0 ].keys ():
140
- outputs [key ] = self ._concatenate_batches (batch [key ] for batch in inputs )
152
+ outputs [key ] = self ._concatenate_batches ([batch [key ] for batch in inputs ])
153
+ # If it's a dict with only one key, just return the array
154
+ if len (outputs ) == 1 :
155
+ outputs = list (outputs .values ())[0 ]
141
156
elif isinstance (inputs [0 ], list ) or isinstance (inputs [0 ], tuple ):
142
157
outputs = []
143
158
for input_list in zip (* inputs ):
144
159
outputs .append (self ._concatenate_batches (input_list ))
160
+ if len (outputs ) == 1 :
161
+ outputs = outputs [0 ] # If it's a list with only one element, just return the array
145
162
elif isinstance (inputs [0 ], np .ndarray ):
146
163
outputs = self ._concatenate_batches (inputs )
164
+ elif isinstance (inputs [0 ], tf .Tensor ):
165
+ outputs = self ._concatenate_batches ([tensor .numpy () for tensor in inputs ])
147
166
else :
148
167
raise TypeError (f"Couldn't handle batch of type { type (inputs [0 ])} !" )
149
168
return outputs
150
169
151
170
def on_epoch_end (self , epoch , logs = None ):
171
+ if hasattr (self .model , "config" ):
172
+ ignore_keys = getattr (self .model .config , "keys_to_ignore_at_inference" , [])
173
+ else :
174
+ ignore_keys = []
175
+
176
+ main_input_name = None
177
+ if self .predict_with_generate :
178
+ # This dense conditional recognizes the case where we have an encoder-decoder model, but
179
+ # avoids getting tangled up when we just have a model with a layer called 'encoder'
180
+ if hasattr (self .model , "encoder" ) and hasattr (self .model .encoder , "main_input_name" ):
181
+ if self .model .encoder .main_input_name != self .model .main_input_name :
182
+ main_input_name = self .model .encoder .main_input_name
183
+ else :
184
+ main_input_name = getattr (self .model , "main_input_name" , "input_ids" )
185
+
152
186
prediction_list = []
153
187
label_list = []
154
188
@@ -160,7 +194,7 @@ def on_epoch_end(self, epoch, logs=None):
160
194
labels = None
161
195
if self .predict_with_generate :
162
196
if isinstance (batch , dict ):
163
- generation_inputs = batch [self . main_input_name ]
197
+ generation_inputs = batch [main_input_name ]
164
198
attention_mask = batch .get ("attention_mask" , None )
165
199
else :
166
200
generation_inputs = batch
@@ -169,9 +203,14 @@ def on_epoch_end(self, epoch, logs=None):
169
203
predictions = self .model .generate (generation_inputs , attention_mask = attention_mask )
170
204
else :
171
205
predictions = self .model .predict (batch )
172
- predictions = dict (predictions )
173
- if self .output_cols is not None :
174
- predictions = {key : predictions [key ] for key in self .output_cols }
206
+ if isinstance (predictions , dict ):
207
+ # This converts any dict-subclass to a regular dict
208
+ # Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
209
+ predictions = dict (predictions )
210
+ if self .output_cols is not None :
211
+ predictions = {key : predictions [key ] for key in self .output_cols }
212
+ else :
213
+ predictions = {key : val for key , val in predictions .items () if key not in ignore_keys + ["loss" ]}
175
214
prediction_list .append (predictions )
176
215
if not self .use_keras_label :
177
216
labels = {key : batch [key ].numpy () for key in self .label_cols }
@@ -185,10 +224,10 @@ def on_epoch_end(self, epoch, logs=None):
185
224
raise TypeError (f"Confused by labels of type { type (labels )} " )
186
225
label_list .append (labels )
187
226
188
- prediction_list = self ._postprocess_predictions_or_labels (prediction_list )
189
- label_list = self ._postprocess_predictions_or_labels (label_list )
227
+ all_preds = self ._postprocess_predictions_or_labels (prediction_list )
228
+ all_labels = self ._postprocess_predictions_or_labels (label_list )
190
229
191
- metric_output = self .metric_fn (prediction_list , label_list , ** self . metric_fn_kwargs )
230
+ metric_output = self .metric_fn (( all_preds , all_labels ) )
192
231
if not isinstance (metric_output , dict ):
193
232
raise TypeError (
194
233
f"metric_fn should return a dict mapping metric names to values but instead returned { metric_output } "
0 commit comments