Skip to content

Commit 6beae76

Browse files
Fix KerasMetricCallback prediction with generate() and inference of column names (huggingface#15351)
* Fix prediction with generate() and the inference of column names.
  Should now have very few differences with the PyTorch implementation.
* Minor edit to parent class
* Update src/transformers/keras_callbacks.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Explaining the dict conversion
* Putting main_input_name back
* Fixes to main_input_name
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent da5ef25 commit 6beae76

File tree

1 file changed

+60
-21
lines changed

1 file changed

+60
-21
lines changed

src/transformers/keras_callbacks.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,6 @@ def rouge_fn(predictions, labels):
5656
metric names to numerical values.
5757
eval_dataset (`tf.data.Dataset` or `dict` or `tuple` or `np.ndarray` or `tf.Tensor`):
5858
Validation data to be used to generate predictions for the `metric_fn`.
59-
metric_fn_kwargs (`dict`, *optional*):
60-
Additional keyword arguments to be passed to the metric_fn.
6159
output_cols (`List[str]`, *optional*):
6260
A list of columns to be retained from the model output as the predictions. Defaults to all.
6361
label_cols (`List[str]`, *optional*):
@@ -74,7 +72,6 @@ def __init__(
7472
self,
7573
metric_fn: Callable,
7674
eval_dataset: Union[tf.data.Dataset, np.ndarray, tf.Tensor, tuple, dict],
77-
metric_fn_kwargs: Optional[dict] = None,
7875
output_cols: Optional[List[str]] = None,
7976
label_cols: Optional[List[str]] = None,
8077
batch_size: Optional[int] = None,
@@ -94,12 +91,6 @@ def __init__(
9491
self.eval_dataset = eval_dataset
9592
self.predict_with_generate = predict_with_generate
9693
self.output_cols = output_cols
97-
self.metric_fn_kwargs = metric_fn_kwargs or dict()
98-
99-
if hasattr(self.model, "encoder") and self.model.encoder.main_input_name != self.model.main_input_name:
100-
self.main_input_name = self.model.encoder.main_input_name
101-
else:
102-
self.main_input_name = self.model.main_input_name
10394

10495
# This next block attempts to parse out which elements of the dataset should be appended to the labels list
10596
# that is passed to the metric_fn
@@ -123,32 +114,75 @@ def __init__(
123114
self.label_cols = ["labels"]
124115
self.use_keras_label = False
125116
logging.warning("No label_cols specified for KerasMetricCallback, assuming you want the 'labels' key.")
117+
elif "start_positions" in input_spec and "end_positions" in input_spec:
118+
self.label_cols = ["start_positions", "end_positions"]
119+
self.use_keras_label = False
120+
logging.warning(
121+
"No label_cols specified for KerasMetricCallback, assuming you want the "
122+
"start_positions and end_positions keys."
123+
)
126124
else:
127125
raise ValueError("Could not autodetect label_cols for KerasMetricCallback, please specify them!")
128-
if parse(tf.__version__).minor < parse("2.7"):
126+
if parse(tf.__version__) < parse("2.7"):
129127
logging.warning("TF versions less than 2.7 may encounter issues with KerasMetricCallback!")
130128

131129
@staticmethod
132-
def _concatenate_batches(batches):
133-
# Flattens Numpy array batches into a list of single samples, where each sample is still np.ndarray
134-
return [sample for batch in batches for sample in batch]
130+
def _concatenate_batches(batches, padding_index=-100):
131+
# If all batches are unidimensional or same length, do a simple concatenation
132+
if batches[0].ndim == 1 or all([batch.shape[1] == batches[0].shape[1] for batch in batches]):
133+
return np.concatenate(batches, axis=0)
134+
135+
# Welp, they're not the same length. Let's do some padding
136+
max_len = max([batch.shape[1] for batch in batches])
137+
num_samples = sum([batch.shape[0] for batch in batches])
138+
output = np.full_like(
139+
batches[0], fill_value=padding_index, shape=[num_samples, max_len] + list(batches[0].shape[2:])
140+
)
141+
# i keeps track of which part of the concatenated array we're writing the next batch to
142+
i = 0
143+
for batch in batches:
144+
output[i : i + len(batch), : batch.shape[1]] = batch
145+
i += len(batch)
146+
return output
135147

136148
def _postprocess_predictions_or_labels(self, inputs):
137149
if isinstance(inputs[0], dict):
138150
outputs = dict()
139151
for key in inputs[0].keys():
140-
outputs[key] = self._concatenate_batches(batch[key] for batch in inputs)
152+
outputs[key] = self._concatenate_batches([batch[key] for batch in inputs])
153+
# If it's a dict with only one key, just return the array
154+
if len(outputs) == 1:
155+
outputs = list(outputs.values())[0]
141156
elif isinstance(inputs[0], list) or isinstance(inputs[0], tuple):
142157
outputs = []
143158
for input_list in zip(*inputs):
144159
outputs.append(self._concatenate_batches(input_list))
160+
if len(outputs) == 1:
161+
outputs = outputs[0] # If it's a list with only one element, just return the array
145162
elif isinstance(inputs[0], np.ndarray):
146163
outputs = self._concatenate_batches(inputs)
164+
elif isinstance(inputs[0], tf.Tensor):
165+
outputs = self._concatenate_batches([tensor.numpy() for tensor in inputs])
147166
else:
148167
raise TypeError(f"Couldn't handle batch of type {type(inputs[0])}!")
149168
return outputs
150169

151170
def on_epoch_end(self, epoch, logs=None):
171+
if hasattr(self.model, "config"):
172+
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
173+
else:
174+
ignore_keys = []
175+
176+
main_input_name = None
177+
if self.predict_with_generate:
178+
# This dense conditional recognizes the case where we have an encoder-decoder model, but
179+
# avoids getting tangled up when we just have a model with a layer called 'encoder'
180+
if hasattr(self.model, "encoder") and hasattr(self.model.encoder, "main_input_name"):
181+
if self.model.encoder.main_input_name != self.model.main_input_name:
182+
main_input_name = self.model.encoder.main_input_name
183+
else:
184+
main_input_name = getattr(self.model, "main_input_name", "input_ids")
185+
152186
prediction_list = []
153187
label_list = []
154188

@@ -160,7 +194,7 @@ def on_epoch_end(self, epoch, logs=None):
160194
labels = None
161195
if self.predict_with_generate:
162196
if isinstance(batch, dict):
163-
generation_inputs = batch[self.main_input_name]
197+
generation_inputs = batch[main_input_name]
164198
attention_mask = batch.get("attention_mask", None)
165199
else:
166200
generation_inputs = batch
@@ -169,9 +203,14 @@ def on_epoch_end(self, epoch, logs=None):
169203
predictions = self.model.generate(generation_inputs, attention_mask=attention_mask)
170204
else:
171205
predictions = self.model.predict(batch)
172-
predictions = dict(predictions)
173-
if self.output_cols is not None:
174-
predictions = {key: predictions[key] for key in self.output_cols}
206+
if isinstance(predictions, dict):
207+
# This converts any dict-subclass to a regular dict
208+
# Keras REALLY doesn't like it when we pass around a BatchEncoding or other derived class
209+
predictions = dict(predictions)
210+
if self.output_cols is not None:
211+
predictions = {key: predictions[key] for key in self.output_cols}
212+
else:
213+
predictions = {key: val for key, val in predictions.items() if key not in ignore_keys + ["loss"]}
175214
prediction_list.append(predictions)
176215
if not self.use_keras_label:
177216
labels = {key: batch[key].numpy() for key in self.label_cols}
@@ -185,10 +224,10 @@ def on_epoch_end(self, epoch, logs=None):
185224
raise TypeError(f"Confused by labels of type {type(labels)}")
186225
label_list.append(labels)
187226

188-
prediction_list = self._postprocess_predictions_or_labels(prediction_list)
189-
label_list = self._postprocess_predictions_or_labels(label_list)
227+
all_preds = self._postprocess_predictions_or_labels(prediction_list)
228+
all_labels = self._postprocess_predictions_or_labels(label_list)
190229

191-
metric_output = self.metric_fn(prediction_list, label_list, **self.metric_fn_kwargs)
230+
metric_output = self.metric_fn((all_preds, all_labels))
192231
if not isinstance(metric_output, dict):
193232
raise TypeError(
194233
f"metric_fn should return a dict mapping metric names to values but instead returned {metric_output}"

0 commit comments

Comments
 (0)