MAINT Parameters validation for sklearn.metrics.average_precision_score #25313

Closed
wants to merge 13 commits into from
9 changes: 9 additions & 0 deletions sklearn/datasets/_california_housing.py
@@ -37,6 +37,7 @@
from ._base import RemoteFileMetadata
from ._base import load_descr
from ..utils import Bunch
from ..utils._param_validation import validate_params


# The original data can be found at:
@@ -50,6 +51,14 @@
logger = logging.getLogger(__name__)


@validate_params(
Member

I think these changes should be done in another PR. Could you revert this file?

{
"data_home": [str, None],
"download_if_missing": ["boolean"],
"return_X_y": ["boolean"],
"as_frame": ["boolean"],
}
)
def fetch_california_housing(
*, data_home=None, download_if_missing=True, return_X_y=False, as_frame=False
):
15 changes: 12 additions & 3 deletions sklearn/metrics/_ranking.py
@@ -34,7 +34,7 @@
from ..utils.multiclass import type_of_target
from ..utils.extmath import stable_cumsum
from ..utils.sparsefuncs import count_nonzero
from ..utils._param_validation import validate_params
from ..utils._param_validation import validate_params, StrOptions
from ..exceptions import UndefinedMetricWarning
from ..preprocessing import label_binarize
from ..utils._encode import _encode, _unique
@@ -112,6 +112,15 @@ def auc(x, y):
return area


@validate_params(
{
"y_true": ["array-like"],
Member

It is true that we are going to call column_or_1d, which will transform any array-like into a NumPy array. Could you update the docstrings of y_true and y_score to replace ndarray with array-like?

"y_score": ["array-like"],
"average": [StrOptions({"micro", "samples", "weighted", "macro"}), None],
"pos_label": [Real, str, "boolean"],
"sample_weight": ["array-like", None],
}
)
def average_precision_score(
y_true, y_score, *, average="macro", pos_label=1, sample_weight=None
):
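
For readers unfamiliar with this pattern, the following is a simplified, hypothetical sketch of what a decorator in the style of `validate_params` does with a constraints dictionary like the one above. It is not scikit-learn's implementation: the names `validate_params_sketch`, `_satisfies`, and `toy_metric` are made up for illustration, a plain `set` stands in for `StrOptions`, `bool` stands in for the `"boolean"` constraint, and the sketch raises a plain `ValueError`.

```python
from functools import wraps
from inspect import signature
from numbers import Real


def _satisfies(value, constraint):
    """Return True if `value` meets one constraint (hypothetical helper)."""
    if constraint is None:
        # None in the constraint list means the literal value None is allowed.
        return value is None
    if isinstance(constraint, (set, frozenset)):
        # A set stands in for a "one of these strings" constraint.
        return isinstance(value, str) and value in constraint
    # Anything else is assumed to be a type to check with isinstance.
    return isinstance(value, constraint)


def validate_params_sketch(constraints):
    """Hypothetical stand-in for a validate_params-style decorator."""

    def decorator(func):
        sig = signature(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Bind the call to the signature so defaults are validated too.
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            for name, value in bound.arguments.items():
                allowed = constraints.get(name)
                if allowed is None:
                    continue  # no constraint declared for this parameter
                if not any(_satisfies(value, c) for c in allowed):
                    raise ValueError(
                        f"The {name!r} parameter of {func.__name__} got "
                        f"{value!r}, which matches none of {allowed!r}."
                    )
            return func(*args, **kwargs)

        return wrapper

    return decorator


@validate_params_sketch(
    {
        "average": [{"micro", "samples", "weighted", "macro"}, None],
        "pos_label": [Real, str, bool],
    }
)
def toy_metric(y_true, y_score, *, average="macro", pos_label=1):
    # Placeholder body: echo the validated options back to the caller.
    return average, pos_label
```

With this sketch, `toy_metric([0, 1], [0.2, 0.8], average=None)` passes validation, while `average="median"` raises `ValueError` before the function body runs. Declaring constraints next to the signature keeps the checks out of the function body, which appears to be the motivation for the decorator approach in this PR.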
@@ -137,10 +146,10 @@ def average_precision_score(

Parameters
----------
y_true : ndarray of shape (n_samples,) or (n_samples, n_classes)
y_true : array-like of shape (n_samples,) or (n_samples, n_classes)
True binary labels or binary label indicators.

y_score : ndarray of shape (n_samples,) or (n_samples, n_classes)
y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Target scores, can either be probability estimates of the positive
class, confidence values, or non-thresholded measure of decisions
(as returned by :term:`decision_function` on some classifiers).
2 changes: 2 additions & 0 deletions sklearn/tests/test_public_functions.py
@@ -100,13 +100,15 @@ def _check_function_param_validation(
"sklearn.cluster.kmeans_plusplus",
"sklearn.covariance.empirical_covariance",
"sklearn.covariance.shrunk_covariance",
"sklearn.datasets.fetch_california_housing",
Member

Can you revert this change?

"sklearn.datasets.make_sparse_coded_signal",
"sklearn.decomposition.sparse_encode",
"sklearn.feature_extraction.grid_to_graph",
"sklearn.feature_extraction.img_to_graph",
"sklearn.feature_extraction.image.extract_patches_2d",
"sklearn.metrics.accuracy_score",
"sklearn.metrics.auc",
"sklearn.metrics.average_precision_score",
"sklearn.metrics.cohen_kappa_score",
"sklearn.metrics.confusion_matrix",
"sklearn.metrics.det_curve",