Default argument pos_label=1 is not ignored in f1_score metric for multiclass classification #29734

Open
@slimebob1975

Describe the bug

On the iris flower classification problem with string class labels, scoring with the f1_score metric and average='micro' raises a ValueError about the default argument value pos_label=1:

ValueError: pos_label=1 is not a valid label: It should be one of ['setosa' 'versicolor' 'virginica']

According to the documentation, the pos_label argument should be ignored for the multiclass problem:

https://scikit-learn.org/stable/modules/generated/sklearn.metrics.f1_score.html#f1-score

The class to report if average='binary' and the data is binary, otherwise this parameter is ignored.
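
As a minimal sketch of the documented behavior, calling f1_score directly (outside a scorer) with string labels and average='micro' should not need a valid pos_label. The toy labels below are made up for illustration and are not taken from the iris run.

```python
from sklearn.metrics import f1_score

# Hypothetical toy labels (not from the iris run): 3 of 4 predictions are correct.
y_true = ["setosa", "versicolor", "virginica", "virginica"]
y_pred = ["setosa", "virginica", "virginica", "virginica"]

# pos_label is left at its default (1), which is not one of the string labels.
# With average='micro' it is not used; for single-label multiclass data the
# micro-averaged F1 equals accuracy (3/4 for these toy labels).
score = f1_score(y_true, y_pred, average="micro")
print(score)
```

This suggests the ValueError reported here comes from the scorer path rather than from f1_score itself, which matches the traceback ending in _get_response_values.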

Setting pos_label explicitly to None avoids the error and produces the expected output (see Expected Results below).
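
A sketch of that workaround, assuming the same iris and LinearDiscriminantAnalysis setup as in the reproduction and the reported sklearn 1.5.0 behavior. The only change is pos_label=None in make_scorer; other versions may behave differently.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import f1_score, make_scorer
from sklearn.model_selection import cross_val_score

data = load_iris()
X = data.data
y = data.target_names[data.target]  # string class labels

# Workaround: pass pos_label=None so the scorer has no positive label to validate
# against the classifier's classes.
f1_scorer = make_scorer(f1_score, average="micro", pos_label=None)

scores = cross_val_score(
    LinearDiscriminantAnalysis(), X, y, cv=5, scoring=f1_scorer
)
print(f"Cross-validated F1 Scores (micro average): {scores}")
print(f"Mean F1 Score: {np.mean(scores)}")
```

If scoring still failed, cross_val_score would by default report nan for each fold with a warning, so checking the scores for nan is a quick way to confirm the workaround took effect.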

Steps/Code to Reproduce

# Import necessary libraries
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.metrics import make_scorer, f1_score

# Load the Iris dataset
data = load_iris()
X = data.data  # Features

# Use string class names as labels instead of the integer targets
y = np.array([data.target_names[label] for label in data.target])

# Initialize the Linear Discriminant Analysis classifier
classifier = LinearDiscriminantAnalysis()

# Define a custom scorer using F1 score with average='micro'
# (pos_label=1 is the same value as the f1_score default)
f1_scorer = make_scorer(f1_score, average='micro', pos_label=1)

# Perform cross-validation with cross_val_score
try:
    scores = cross_val_score(classifier, X, y, cv=5, scoring=f1_scorer)
    print(f"Cross-validated F1 Scores (micro average): {scores}")
    print(f"Mean F1 Score: {np.mean(scores)}")
except ValueError as e:
    print(f"Error: {e}")

Expected Results

Cross-validated F1 Scores (micro average): [1.         1.         0.96666667 0.93333333 1.        ]
Mean F1 Score: 0.9800000000000001

Actual Results

Cross-validated F1 Scores (micro average): [nan nan nan nan nan]
Mean F1 Score: nan
C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\model_selection\_validation.py:1000: UserWarning: Scoring failed. The score on this train-test partition for these parameters will be set to nan. Details: 
Traceback (most recent call last):
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 139, in __call__
    score = scorer._score(
            ^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 371, in _score
    y_pred = method_caller(
             ^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\metrics\_scorer.py", line 89, in _cached_call
    result, _ = _get_response_values(
                ^^^^^^^^^^^^^^^^^^^^^
  File "C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\utils\_response.py", line 204, in _get_response_values
    raise ValueError(
ValueError: pos_label=1 is not a valid label: It should be one of ['setosa' 'versicolor' 'virginica']

Versions

System:
    python: 3.11.5 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:26:23) [MSC v.1916 64 bit (AMD64)]
executable: C:\Users\rgt0227\AppData\Local\anaconda3\python.exe
   machine: Windows-10-10.0.19045-SP0

Python dependencies:
      sklearn: 1.5.0
          pip: 23.2.1
   setuptools: 68.0.0
        numpy: 1.26.2
        scipy: 1.11.4
       Cython: None
       pandas: 2.1.1
   matplotlib: 3.8.0
       joblib: 1.2.0
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: mkl
    num_threads: 8
         prefix: mkl_rt
       filepath: C:\Users\rgt0227\AppData\Local\anaconda3\Library\bin\mkl_rt.2.dll
        version: 2023.1-Product
threading_layer: intel

       user_api: openmp
   internal_api: openmp
    num_threads: 8
         prefix: vcomp
       filepath: C:\Users\rgt0227\AppData\Local\anaconda3\Lib\site-packages\sklearn\.libs\vcomp140.dll
        version: None
