Skip to content

Commit 139e830

Browse files
authored
Update label2id in the model config for run_glue (huggingface#13334)
1 parent 6f3c99a commit 139e830

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

examples/pytorch/text-classification/run_glue.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,9 @@ def main():
380380
if label_to_id is not None:
381381
model.config.label2id = label_to_id
382382
model.config.id2label = {id: label for label, id in config.label2id.items()}
383+
elif data_args.task_name is not None and not is_regression:
384+
model.config.label2id = {l: i for i, l in enumerate(label_list)}
385+
model.config.id2label = {id: label for label, id in config.label2id.items()}
383386

384387
if data_args.max_seq_length > tokenizer.model_max_length:
385388
logger.warning(

examples/pytorch/text-classification/run_glue_no_trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,9 @@ def main():
288288
if label_to_id is not None:
289289
model.config.label2id = label_to_id
290290
model.config.id2label = {id: label for label, id in config.label2id.items()}
291+
elif args.task_name is not None and not is_regression:
292+
model.config.label2id = {l: i for i, l in enumerate(label_list)}
293+
model.config.id2label = {id: label for label, id in config.label2id.items()}
291294

292295
padding = "max_length" if args.pad_to_max_length else False
293296

examples/tensorflow/text-classification/run_glue.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,9 @@ def main():
355355
if label_to_id is not None:
356356
config.label2id = label_to_id
357357
config.id2label = {id: label for label, id in config.label2id.items()}
358+
elif data_args.task_name is not None and not is_regression:
359+
config.label2id = {l: i for i, l in enumerate(label_list)}
360+
config.id2label = {id: label for label, id in config.label2id.items()}
358361

359362
if data_args.max_seq_length > tokenizer.model_max_length:
360363
logger.warning(

0 commit comments

Comments
 (0)