diff --git a/doc/whats_new/_contributors.rst b/doc/whats_new/_contributors.rst
index c74a2964e57bc..d56bdb9bc45ee 100644
--- a/doc/whats_new/_contributors.rst
+++ b/doc/whats_new/_contributors.rst
@@ -179,3 +179,5 @@
 .. _Guillaume Lemaitre: https://github.com/glemaitre
 
 .. _Tim Head: https://betatim.github.io/
+
+.. _Alex Shtoff: https://alexshtf.github.io/
\ No newline at end of file
diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst
new file mode 100644
index 0000000000000..7aef8b28925f8
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst
@@ -0,0 +1,8 @@
+The new function :func:`metrics.rec_curve` computes the Regression Error Characteristic
+(REC) curve: error tolerances vs. the fraction of samples with error below the tolerance
+(the CDF of the regression errors). It was suggested in the ICML'03 paper "Regression
+error characteristic curves" by Bi and Bennett as the regression variant of the ROC curve.
+
+The new class :class:`metrics.RecCurveDisplay` visualizes REC curves.
+
+By :user:`Alex Shtoff <alexshtf>`
\ No newline at end of file
diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py
new file mode 100644
index 0000000000000..04988aa3d50d7
--- /dev/null
+++ b/examples/miscellaneous/plot_rec_curve_visualization.py
@@ -0,0 +1,105 @@
+"""
+=============================================
+Regression Error Characteristic curve display
+=============================================
+
+We illustrate how to display the regression error characteristic curve and how to
+use it to compare various regression models.
+"""
+
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+# %%
+# Data Loading and Preparation
+# ----------------------------
+#
+# Load the California housing dataset. Then, we split the data and target into
+# training and test sets.
from sklearn.datasets import fetch_california_housing
+from sklearn.model_selection import train_test_split
+
+X, y = fetch_california_housing(return_X_y=True, as_frame=True)
+X_train, X_test, y_train, y_test = train_test_split(
+    X, y, test_size=0.2, random_state=42, shuffle=True
+)
+
+# %%
+# Linear regression model
+# -----------------------
+#
+# We create a linear regression model and fit it on the training set after
+# standard scaling.
+from sklearn.linear_model import LinearRegression
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+
+lr_estimator = make_pipeline(StandardScaler(), LinearRegression())
+lr_estimator.fit(X_train, y_train)
+
+# %%
+# Display the REC curve using the test set. Note that our linear regressor's REC
+# curve dominates the curve of the default constant predictor, the median. This
+# indicates that our model is doing something "reasonable".
+from sklearn.metrics import RecCurveDisplay
+
+RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test, name="Linear regression")
+
+
+# %%
+# Compare the REC curve of linear regression to those of KNN and histogram gradient
+# boosting regressors. We can see a clear hierarchy here: KNN is strictly better
+# than linear regression for any error tolerance, and the histogram gradient
+# boosting regressor (HGBR) is strictly better than both.
+import matplotlib.pyplot as plt
+
+from sklearn.ensemble import HistGradientBoostingRegressor
+from sklearn.metrics import mean_absolute_error, root_mean_squared_error
+from sklearn.neighbors import KNeighborsRegressor
+
+hgbr_estimator = HistGradientBoostingRegressor()
+hgbr_estimator.fit(X_train, y_train)
+
+knn_estimator = make_pipeline(StandardScaler(), KNeighborsRegressor())
+knn_estimator.fit(X_train, y_train)
+
+pred_lr = lr_estimator.predict(X_test)
+pred_hgbr = hgbr_estimator.predict(X_test)
+pred_knn = knn_estimator.predict(X_test)
+
+lr_metrics = (
+    f"RMSE = {root_mean_squared_error(y_test, pred_lr):.2f}, "
+    f"MAE = {mean_absolute_error(y_test, pred_lr):.2f}"
+)
+hgbr_metrics = (
+    f"RMSE = {root_mean_squared_error(y_test, pred_hgbr):.2f}, "
+    f"MAE = {mean_absolute_error(y_test, pred_hgbr):.2f}"
+)
+knn_metrics = (
+    f"RMSE = {root_mean_squared_error(y_test, pred_knn):.2f}, "
+    f"MAE = {mean_absolute_error(y_test, pred_knn):.2f}"
+)
+
+fig, ax = plt.subplots()
+RecCurveDisplay.from_predictions(
+    y_test,
+    pred_lr,
+    ax=ax,
+    name=f"Linear regression ({lr_metrics})",
+    plot_const_predictor=False,
+)
+RecCurveDisplay.from_predictions(
+    y_test,
+    pred_knn,
+    ax=ax,
+    name=f"KNN ({knn_metrics})",
+    plot_const_predictor=False,
+)
+RecCurveDisplay.from_predictions(
+    y_test,
+    pred_hgbr,
+    ax=ax,
+    name=f"Gradient Boosting Regressor ({hgbr_metrics})",
+)
+plt.show()
diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py
index ce86525acc368..b8281de7ae038 100644
--- a/sklearn/metrics/__init__.py
+++ b/sklearn/metrics/__init__.py
@@ -30,6 +30,7 @@
 from ._plot.confusion_matrix import ConfusionMatrixDisplay
 from ._plot.det_curve import DetCurveDisplay
 from ._plot.precision_recall_curve import PrecisionRecallDisplay
+from ._plot.rec_curve import RecCurveDisplay
 from ._plot.regression import PredictionErrorDisplay
 from ._plot.roc_curve import RocCurveDisplay
 from ._ranking import (
@@ -65,6 +66,7 @@
     root_mean_squared_error,
     root_mean_squared_log_error,
 )
+from ._regression_characteristic import rec_curve
 from ._scorer import check_scoring, get_scorer, get_scorer_names, make_scorer
 from .cluster import (
     adjusted_mutual_info_score,
@@ -100,6 +102,7 @@
     "DistanceMetric",
     "PrecisionRecallDisplay",
     "PredictionErrorDisplay",
+    "RecCurveDisplay",
     "RocCurveDisplay",
     "accuracy_score",
     "adjusted_mutual_info_score",
@@ -168,6 +171,7 @@
     "precision_score",
     "r2_score",
     "rand_score",
+    "rec_curve",
     "recall_score",
     "roc_auc_score",
     "roc_curve",
diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py
new file mode 100644
index 0000000000000..cbdbce557d1bc
--- /dev/null
+++ b/sklearn/metrics/_plot/rec_curve.py
@@ -0,0 +1,363 @@
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+import numpy as np
+
+from ...base import is_regressor  # To check whether the estimator is a regressor
+from ...metrics._regression_characteristic import rec_curve
+from ...utils._plotting import _CurveDisplayMixin, _validate_style_kwargs
+from ...utils.validation import check_is_fitted
+
+
+class RecCurveDisplay(_CurveDisplayMixin):
+    """Regression Error Characteristic (REC) Curve visualization.
+
+    It is recommended to use :func:`~sklearn.metrics.RecCurveDisplay.from_estimator`
+    or :func:`~sklearn.metrics.RecCurveDisplay.from_predictions` to create
+    a visualizer.
+
+    For general information regarding `scikit-learn` visualization tools, see
+    the :ref:`Visualization Guide <visualizations>`.
+
+    Parameters
+    ----------
+    deviations : ndarray
+        Sorted unique error tolerance values (x-coordinates).
+    accuracy : ndarray
+        Corresponding accuracy values (y-coordinates).
+    estimator_name : str, default=None
+        Name of the estimator. If `None`, then the name will be `"Model"`.
+    loss : {'absolute', 'squared'}, default='absolute'
+        The loss function used to compute the REC curve.
+    constant_predictor_deviations : ndarray, default=None
+        Deviations for the constant predictor's REC curve. `None` if not plotted
+        or not computed.
+    constant_predictor_accuracy : ndarray, default=None
+        Accuracy for the constant predictor's REC curve. `None` if not plotted
+        or not computed.
+    constant_predictor_name : str, default=None
+        Name of the constant predictor. `None` if not plotted or not computed.
+
+    Attributes
+    ----------
+    line_ : matplotlib Artist
+        REC curve.
+    ax_ : matplotlib Axes
+        Axes with REC curve.
+    figure_ : matplotlib Figure
+        Figure containing the curve.
+    constant_predictor_line_ : matplotlib Artist, default=None
+        Constant predictor REC curve. Only defined if a constant predictor
+        was plotted.
+
+    See Also
+    --------
+    rec_curve : Compute Regression Error Characteristic (REC) curve.
+
+    Notes
+    -----
+    This class uses the heuristic suggested in the REC paper to clip the x-axis
+    at the maximum error achieved by some reasonable constant predictor, such as
+    the mean or the median. This heuristic makes plots more interpretable, since
+    it focuses on the portion of "reasonable" errors: a predictor that is worse
+    than a constant predictor is practically worthless.
+
+    Examples
+    --------
+    >>> import matplotlib.pyplot as plt
+    >>> import numpy as np
+    >>> from sklearn.linear_model import LinearRegression
+    >>> from sklearn.metrics import RecCurveDisplay
+
+    >>> # from_estimator example
+    >>> X = np.array([[1], [2], [3], [4], [5]])
+    >>> y = np.array([1, 2.5, 3, 4.5, 5])
+    >>> estimator = LinearRegression().fit(X, y)
+    >>> RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute')
+    <...>
+    >>> plt.show()
+
+    >>> # from_predictions example
+    >>> y_pred = estimator.predict(X)
+    >>> RecCurveDisplay.from_predictions(
+    ...     y, y_pred, loss='squared', name="My Model", plot_const_predictor=False
+    ... )
+    <...>
+    >>> plt.show()
+    """
+
+    def __init__(
+        self,
+        *,
+        deviations,
+        accuracy,
+        estimator_name=None,
+        loss="absolute",
+        constant_predictor_deviations=None,
+        constant_predictor_accuracy=None,
+        constant_predictor_name=None,
+    ):
+        self.deviations = deviations
+        self.accuracy = accuracy
+        self.estimator_name = estimator_name if estimator_name is not None else "Model"
+        self.loss = loss
+
+        self.constant_predictor_deviations = constant_predictor_deviations
+        self.constant_predictor_accuracy = constant_predictor_accuracy
+        self.constant_predictor_name = constant_predictor_name
+
+    def plot(
+        self,
+        ax=None,
+        *,
+        name=None,
+        plot_const_predictor=True,
+        clip_max_const_error=True,
+        **kwargs,
+    ):
+        """Plot visualization.
+
+        Parameters
+        ----------
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is created.
+        name : str, default=None
+            Name of REC curve for labeling. If `None`, use the name stored in
+            `estimator_name`.
+        plot_const_predictor : bool, default=True
+            If `True`, also plot the REC curve of the constant predictor.
+        clip_max_const_error : bool, default=True
+            If `True`, clip the horizontal axis at the maximum error achieved
+            by the constant predictor.
+        **kwargs : dict
+            Keyword arguments to be passed to `matplotlib.pyplot.plot` for the
+            main REC curve.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.RecCurveDisplay`
+            Object that stores computed values.
+        """
+        self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name)
+
+        line_kwargs = {}
+        if "label" not in kwargs:  # Allow the user to override the label
+            line_kwargs["label"] = name
+        line_kwargs = _validate_style_kwargs(line_kwargs, kwargs)
+
+        if self.constant_predictor_deviations is not None:
+            max_const_error = max(self.constant_predictor_deviations)
+        elif clip_max_const_error:
+            raise ValueError(
+                "clip_max_const_error is True, but no constant predictor "
+                "deviations were given."
+            )
+
+        if clip_max_const_error:
+            mask = self.deviations <= max_const_error
+            self.line_, *_ = self.ax_.plot(
+                self.deviations[mask], self.accuracy[mask], **line_kwargs
+            )
+        else:
+            self.line_, *_ = self.ax_.plot(
+                self.deviations, self.accuracy, **line_kwargs
+            )
+
+        # Plot the constant predictor if its data exists and the flag is set
+        if (
+            plot_const_predictor
+            and self.constant_predictor_deviations is not None
+            and self.constant_predictor_accuracy is not None
+        ):
+            cp_name = (
+                self.constant_predictor_name
+                if self.constant_predictor_name
+                else "Constant Predictor"
+            )
+            # Default style for the constant predictor, can be overridden if needed
+            cp_kwargs = {"label": cp_name, "linestyle": "--"}
+            cp_kwargs = _validate_style_kwargs(cp_kwargs, kwargs)
+            self.constant_predictor_line_, *_ = self.ax_.plot(
+                self.constant_predictor_deviations,
+                self.constant_predictor_accuracy,
+                **cp_kwargs,
+            )
+
+        self.ax_.set_xlabel(f"Error Tolerance (Deviation - {self.loss} loss)")
+        self.ax_.set_ylabel("Accuracy (Fraction of samples)")
+        self.ax_.set_title("Regression Error Characteristic (REC) Curve")
+        self.ax_.legend(loc="lower right")
+        self.ax_.grid(True)
+
+        return self
+
+    @classmethod
+    def from_estimator(
+        cls,
+        estimator,
+        X,
+        y,
+        *,
+        loss="absolute",
+        constant_predictor=None,
+        plot_const_predictor=True,
+        clip_max_const_error=True,
+        name=None,
+        ax=None,
+        **kwargs,
+    ):
+        """Create a REC Curve display from an estimator.
+
+        Parameters
+        ----------
+        estimator : object
+            Fitted estimator or a fitted :class:`~sklearn.pipeline.Pipeline`
+            in which the last estimator is a regressor.
+        X : {array-like, sparse matrix} of shape (n_samples, n_features)
+            Input values.
+        y : array-like of shape (n_samples,)
+            Target values.
+        loss : {'absolute', 'squared'}, default='absolute'
+            The loss function to use for calculating deviations.
+        constant_predictor : {'mean', 'median', None}, default=None
+            The type of constant predictor to plot as a baseline.
+            If 'mean', uses the mean of `y_true`.
+            If 'median', uses the median of `y_true`.
+            If `None`, chooses 'mean' for 'squared' loss and 'median' for
+            'absolute' loss, as these are the optimal constant predictors
+            for these losses.
+        plot_const_predictor : bool, default=True
+            Whether to compute and plot the REC curve for the constant predictor.
+        clip_max_const_error : bool, default=True
+            If `True`, the x-axis (error tolerance) will be cut off at the
+            maximum error achieved by the constant predictor.
+        name : str, default=None
+            Name for the REC curve. If `None`, the estimator's class name will
+            be used.
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is created.
+        **kwargs : dict
+            Keyword arguments to be passed to `matplotlib.pyplot.plot` for the
+            estimator's REC curve.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.RecCurveDisplay`
+            Object that stores computed values.
+        """
+        check_is_fitted(estimator)
+        if not is_regressor(estimator):
+            raise TypeError(f"{estimator.__class__.__name__} is not a regressor.")
+
+        y_pred = estimator.predict(X)
+
+        if name is None:
+            name = estimator.__class__.__name__
+
+        # Call from_predictions with the validated parameters
+        return cls.from_predictions(
+            y_true=y,
+            y_pred=y_pred,
+            loss=loss,
+            constant_predictor=constant_predictor,
+            plot_const_predictor=plot_const_predictor,
+            clip_max_const_error=clip_max_const_error,
+            name=name,
+            ax=ax,
+            **kwargs,
+        )
+
+    @classmethod
+    def from_predictions(
+        cls,
+        y_true,
+        y_pred,
+        *,
+        loss="absolute",
+        constant_predictor=None,
+        plot_const_predictor=True,
+        clip_max_const_error=True,
+        name=None,
+        ax=None,
+        **kwargs,
+    ):
+        """Plot REC curve given true and predicted values.
+
+        Parameters
+        ----------
+        y_true : array-like of shape (n_samples,)
+            True target values.
+        y_pred : array-like of shape (n_samples,) or scalar
+            Estimated target values. A scalar is treated by `rec_curve` as a
+            constant prediction for all samples.
+        loss : {'absolute', 'squared'}, default='absolute'
+            The loss function to use for calculating deviations.
+        constant_predictor : {'mean', 'median', None}, default=None
+            The type of constant predictor to plot as a baseline.
+            If 'mean', uses the mean of `y_true`.
+            If 'median', uses the median of `y_true`.
+            If `None`, chooses 'mean' for 'squared' loss and 'median' for
+            'absolute' loss.
+        plot_const_predictor : bool, default=True
+            Whether to compute and plot the REC curve for the constant predictor.
+        clip_max_const_error : bool, default=True
+            If `True`, the x-axis (error tolerance) will be cut off at the
+            maximum error achieved by the constant predictor. This is only
+            effective if a constant predictor is computed.
+        name : str, default=None
+            Name for the REC curve. If `None`, the name "Model" is used.
+        ax : matplotlib axes, default=None
+            Axes object to plot on. If `None`, a new figure and axes is created.
+        **kwargs : dict
+            Keyword arguments to be passed to `matplotlib.pyplot.plot` for the
+            main REC curve.
+
+        Returns
+        -------
+        display : :class:`~sklearn.metrics.RecCurveDisplay`
+            Object that stores computed values.
+        """
+        main_devs, main_acc = rec_curve(y_true, y_pred, loss=loss)
+
+        # Convert to NumPy arrays for consistent handling in the display class,
+        # as rec_curve may accept and return array API specific arrays.
+        y_true_np = np.asarray(y_true)
+        main_devs_np = np.asarray(main_devs)
+        main_acc_np = np.asarray(main_acc)
+
+        # Determine the constant predictor type if not specified
+        actual_constant_predictor_type = constant_predictor
+        if actual_constant_predictor_type is None:
+            if loss == "squared":
+                actual_constant_predictor_type = "mean"
+            else:  # loss == "absolute"
+                actual_constant_predictor_type = "median"
+
+        # Compute the constant predictor data if needed for plotting or clipping
+        if plot_const_predictor or clip_max_const_error:
+            if actual_constant_predictor_type == "mean":
+                constant_value = np.mean(y_true_np)
+                cp_name_val = "Mean Predictor"
+            elif actual_constant_predictor_type == "median":
+                constant_value = np.median(y_true_np)
+                cp_name_val = "Median Predictor"
+            else:
+                raise ValueError(
+                    "constant_predictor must be one of 'mean', 'median' or None, "
+                    f"got {constant_predictor!r} instead."
+                )
+
+            cp_devs, cp_accs = rec_curve(y_true, constant_value, loss=loss)
+            cp_devs_np = np.asarray(cp_devs)
+            cp_accs_np = np.asarray(cp_accs)
+        else:
+            cp_devs_np, cp_accs_np, cp_name_val = None, None, None
+
+        display_name = name if name is not None else "Model"
+
+        obj = RecCurveDisplay(
+            deviations=main_devs_np,
+            accuracy=main_acc_np,
+            estimator_name=display_name,
+            loss=loss,  # loss is already validated by rec_curve
+            constant_predictor_deviations=cp_devs_np,
+            constant_predictor_accuracy=cp_accs_np,
+            constant_predictor_name=cp_name_val,
+        )
+        return obj.plot(
+            ax=ax,
+            plot_const_predictor=plot_const_predictor,
+            clip_max_const_error=clip_max_const_error,
+            **kwargs,
+        )
diff --git a/sklearn/metrics/_plot/tests/test_rec_curve_display.py b/sklearn/metrics/_plot/tests/test_rec_curve_display.py
new file mode 100644
index 0000000000000..4413f26b287f0
--- /dev/null
+++ b/sklearn/metrics/_plot/tests/test_rec_curve_display.py
@@ -0,0 +1,154 @@
+import numpy as np
+import pytest
+from numpy.testing import assert_allclose
+
+from sklearn import clone
+from sklearn.compose import make_column_transformer
+from sklearn.datasets import make_regression
+from sklearn.exceptions import NotFittedError
+from sklearn.linear_model import LinearRegression
+from sklearn.metrics import RecCurveDisplay, rec_curve
+from sklearn.pipeline import make_pipeline
+from sklearn.preprocessing import StandardScaler
+
+
+@pytest.fixture(scope="module")
+def regression_data():
+    X, y = make_regression(
+        n_samples=200,
+        n_features=20,
+        n_informative=5,
+        noise=1,
+        random_state=42,
+    )
+    return X, y
+
+
+@pytest.mark.parametrize("plot_const_predictor", [True, False])
+@pytest.mark.parametrize("clip_max_const_error", [True, False])
+@pytest.mark.parametrize("constant_predictor", [None, "mean", "median"])
+@pytest.mark.parametrize("loss", [None, "squared", "absolute"])
+@pytest.mark.parametrize(
+    "constructor_name, default_name",
+    [
+        ("from_estimator", "LinearRegression"),
+        ("from_predictions", "Model"),
+    ],
+)
+def test_rec_curve_display_plotting(
+    pyplot,
+    regression_data,
+    plot_const_predictor,
+    clip_max_const_error,
+    constant_predictor,
+    loss,
+    constructor_name,
+    default_name,
+):
+    X, y = regression_data
+
+    lr = LinearRegression()
+    lr.fit(X, y)
+    y_score = lr.predict(X)
+
+    ctor_kwargs = dict(
+        plot_const_predictor=plot_const_predictor,
+        clip_max_const_error=clip_max_const_error,
+    )
+    if loss is not None:
+        ctor_kwargs |= dict(loss=loss)
+    if constant_predictor is not None:
+        ctor_kwargs |= dict(constant_predictor=constant_predictor)
+
+    if constructor_name == "from_estimator":
+        display = RecCurveDisplay.from_estimator(lr, X, y, **ctor_kwargs)
+    else:
+        display = RecCurveDisplay.from_predictions(y, y_score, **ctor_kwargs)
+
+    if loss is None:
+        deviations, accuracy = rec_curve(y, y_score)
+    else:
+        deviations, accuracy = rec_curve(y, y_score, loss=loss)
+
+    assert_allclose(display.deviations, deviations)
+    assert_allclose(display.accuracy, accuracy)
+
+    if plot_const_predictor or clip_max_const_error:
+        const_predictor_labels = {
+            "mean": "Mean Predictor",
+            "median": "Median Predictor",
+        }
+        const_predictor_key = constant_predictor or (
+            "mean" if loss == "squared" else "median"
+        )
+        assert (
+            display.constant_predictor_name
+            == const_predictor_labels[const_predictor_key]
+        )
+
+        const_value = np.mean(y) if const_predictor_key == "mean" else np.median(y)
+        if loss is None:
+            exp_const_deviations, exp_const_accuracy = rec_curve(y, const_value)
+        else:
+            exp_const_deviations, exp_const_accuracy = rec_curve(
+                y, const_value, loss=loss
+            )
+        assert_allclose(display.constant_predictor_deviations, exp_const_deviations)
+        assert_allclose(display.constant_predictor_accuracy, exp_const_accuracy)
+
+    assert display.estimator_name == default_name
+    assert display.loss == (loss or "absolute")
+
+    import matplotlib as mpl
+
+    assert isinstance(display.line_, mpl.lines.Line2D)
+    assert isinstance(display.ax_, mpl.axes.Axes)
+    assert isinstance(display.figure_, mpl.figure.Figure)
+
+    if plot_const_predictor:
+        assert isinstance(display.constant_predictor_line_, mpl.lines.Line2D)
+
+    expected_label = default_name
+    assert display.line_.get_label() == expected_label
+
+    expected_ylabel = "Accuracy (Fraction of samples)"
+    expected_xlabel = f"Error Tolerance (Deviation - {loss or 'absolute'} loss)"
+
+    assert display.ax_.get_ylabel() == expected_ylabel
+    assert display.ax_.get_xlabel() == expected_xlabel
+
+
+@pytest.mark.parametrize(
+    "clf",
+    [
+        LinearRegression(),
+        make_pipeline(StandardScaler(), LinearRegression()),
+        make_pipeline(
+            make_column_transformer((StandardScaler(), [0, 1])), LinearRegression()
+        ),
+    ],
+)
+@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"])
+def test_rec_curve_display_complex_pipeline(
+    pyplot, regression_data, clf, constructor_name
+):
+    """Check the behaviour with complex pipelines."""
+    X, y = regression_data
+
+    clf = clone(clf)
+
+    if constructor_name == "from_estimator":
+        with pytest.raises(NotFittedError):
+            RecCurveDisplay.from_estimator(clf, X, y)
+
+    clf.fit(X, y)
+
+    if constructor_name == "from_estimator":
+        display = RecCurveDisplay.from_estimator(clf, X, y)
+        name = clf.__class__.__name__
+    else:
+        display = RecCurveDisplay.from_predictions(y, y)
+        name = "Model"
+
+    assert name in display.line_.get_label()
+    assert display.estimator_name == name
diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py
new file mode 100644
index 0000000000000..1e64ab71689fd
--- /dev/null
+++ b/sklearn/metrics/_regression_characteristic.py
@@ -0,0 +1,166 @@
+"""Regression Error Characteristic curve"""
+
+# Authors: The scikit-learn developers
+# SPDX-License-Identifier: BSD-3-Clause
+
+import numbers  # For type checking Python scalars
+
+from ..utils import check_array, check_consistent_length
+from ..utils._array_api import _find_matching_floating_dtype, get_namespace_and_device
+from ..utils._param_validation import StrOptions, validate_params
+
+
+@validate_params(
+    {
+        "y_true": ["array-like"],
+        "y_pred": ["array-like", numbers.Real],  # Allow a scalar or an array-like
+        "loss": [StrOptions({"absolute", "squared"})],
+    },
+    prefer_skip_nested_validation=True,
+)
+def rec_curve(y_true, y_pred, *, loss="absolute"):
"""Compute Regression Error Characteristic (REC) curve. + + The REC curve evaluates regression models by plotting the error tolerance + (deviation) on the x-axis against the percentage of data points predicted + within that tolerance (accuracy) on the y-axis. It is the empirical + Cumulative Distribution Function (CDF) of the error. + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True target values. + + y_pred : array-like of shape (n_samples,) or scalar + Estimated target values. If a scalar is provided, it is treated as + a constant prediction for all samples. + + loss : {'absolute', 'squared'}, default='absolute' + The type of loss to use for calculating deviations. + - 'absolute': Uses absolute deviations |y_true - y_pred|. + - 'squared': Uses squared deviations (y_true - y_pred)^2. + + Returns + ------- + deviations : ndarray + Sorted unique error tolerance values. These are the x-coordinates + for the REC curve. The array will start with 0.0 if the smallest + calculated error is greater than 0, representing zero tolerance. + + accuracy : ndarray + The corresponding accuracy (fraction of samples with error less than + or equal to the deviation). These are the y-coordinates for the REC + curve. The array will start with 0.0 if `deviations` starts with an + explicit 0.0. + + See Also + -------- + roc_curve : Compute Receiver Operating Characteristic (ROC) curve. + det_curve : Compute Detection Error Tradeoff (DET) curve. + + References + ---------- + .. [1] Bi, J., & Bennett, K. P. (2003). Regression error characteristic + curves. In Proceedings of the 20th International Conference on + Machine Learning (ICML-03) (pp. 43-50). + + Examples + -------- + >>> import numpy as np + >>> from sklearn.metrics import rec_curve + + >>> # example with absolute loss + >>> y_true = np.array([1, 2, 3, 4, 5, 6]) + >>> y_pred_model1 = np.array([1.1, 2.2, 2.8, 4.3, 4.8, 6.5]) + >>> deviations, accuracy = rec_curve(y_true, y_pred_model1, loss='absolute') + >>> deviations + array([0. , 0.1 , 0.2 , 0.3 , 0.5 ]) + >>> accuracy + array([0. , 0.16666667, 0.66666667, 0.83333333, 1. ]) + + >>> # Example with a scalar prediction (constant model) + >>> y_pred_scalar = 3.5 + >>> dev_scalar, acc_scalar = rec_curve(y_true, y_pred_scalar) + >>> dev_scalar + array([0. , 0.5, 1.5, 2.5]) + >>> acc_scalar + array([0. , 0.33333333, 0.66666667, 1. ]) + + >>> # Example with squared loss + >>> dev_sq, acc_sq = rec_curve(y_true, y_pred_model1, loss='squared') + >>> dev_sq # These are squared errors + array([0. , 0.01 , 0.04 , 0.09 , 0.25 ]) + >>> acc_sq + array([0. , 0.16666667, 0.66666667, 0.83333333, 1. ]) + """ + # Validate y_true and get the array namespace (xp) + y_true_array = check_array( + y_true, ensure_2d=False, dtype="numeric", ensure_all_finite=True + ) + xp, _, device = get_namespace_and_device(y_true_array) + + # Handle Python native scalars (int, float) + if isinstance(y_pred, numbers.Number): # numbers.Real covers int, float + y_pred_scalar_val = float(y_pred) + y_pred = xp.full( + y_true_array.shape, + fill_value=y_pred_scalar_val, + dtype=_find_matching_floating_dtype(y_true_array, xp=xp), + device=device, + ) + + # Handle array-like with a single prediction + if y_pred.size == 1: + y_pred = xp.squeeze(y_pred) + y_pred = xp.tile(y_pred, y_true_array.shape) + + y_pred_array = check_array( + y_pred, ensure_2d=False, dtype="numeric", ensure_all_finite=True + ) + check_consistent_length(y_true_array, y_pred_array) + + # cast to common floating point dtype. 
+    common_dtype = _find_matching_floating_dtype(y_true_array, y_pred_array, xp=xp)
+    y_true_array = xp.astype(y_true_array, common_dtype, copy=False)
+    y_pred_array = xp.astype(y_pred_array, common_dtype, copy=False)
+
+    # Calculate deviations based on the chosen loss
+    differences = y_true_array - y_pred_array
+    if loss == "absolute":
+        errors = xp.abs(differences)
+    else:  # loss == "squared"
+        errors = xp.square(differences)
+
+    n_samples = y_true_array.shape[0]
+
+    # Handle empty input (no samples)
+    if n_samples == 0:
+        empty_float_array = xp.asarray([], dtype=xp.float64, device=device)
+        return empty_float_array, empty_float_array
+
+    # Compute the unique deviations and their counts in sorted order
+    deviations_calc, counts = xp.unique_counts(errors)
+    sort_order = xp.argsort(deviations_calc)
+    deviations_calc = deviations_calc[sort_order]
+    counts = counts[sort_order]
+
+    # Calculate the cumulative accuracy
+    cumulative_counts = xp.cumsum(counts)
+    accuracy_values = xp.astype(cumulative_counts, xp.float64) / float(n_samples)
+
+    # Prepend (0, 0) if the smallest error (first element of deviations_calc)
+    # is > 0.0, ensuring the curve starts from the origin of the plot unless
+    # there are samples with exactly zero error.
+    if deviations_calc[0] > 0.0:
+        zero_dev = xp.asarray([0.0], dtype=deviations_calc.dtype, device=device)
+        zero_acc = xp.asarray([0.0], dtype=accuracy_values.dtype, device=device)
+
+        deviations_out = xp.concatenate((zero_dev, deviations_calc))
+        accuracy_out = xp.concatenate((zero_acc, accuracy_values))
+    else:
+        deviations_out = deviations_calc
+        accuracy_out = accuracy_values
+
+    return deviations_out, accuracy_out
diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py
index 5e90727583189..6f8755403e7e1 100644
--- a/sklearn/metrics/tests/test_regression.py
+++ b/sklearn/metrics/tests/test_regression.py
@@ -23,6 +23,7 @@
     mean_tweedie_deviance,
     median_absolute_error,
     r2_score,
+    rec_curve,
     root_mean_squared_error,
     root_mean_squared_log_error,
 )
@@ -632,3 +633,35 @@ def test_pinball_loss_relation_with_mae(global_random_seed):
         mean_absolute_error(y_true, y_pred)
         == mean_pinball_loss(y_true, y_pred, alpha=0.5) * 2
     )
+
+
+@pytest.mark.parametrize("constant_pred", [3.5, np.asarray(3.5), np.asarray([3.5])])
+def test_rec_curve_const_pred(constant_pred):
+    y_true = np.array([5, 6, 4, 3, 1, 2])
+
+    deviations, accuracy = rec_curve(y_true, constant_pred)
+
+    assert_allclose(deviations, np.asarray([0.0, 0.5, 1.5, 2.5]))
+    assert_allclose(accuracy, np.asarray([0.0, 1 / 3.0, 2 / 3.0, 1.0]))
+
+
+def test_rec_curve_array_pred():
+    # Four residuals of 1, and one residual of 0
+    y_true = np.array([-1, 1, 2, -2, 0])
+    y_pred = np.array([0, 2, 3, -1, 0])
+
+    deviations, accuracy = rec_curve(y_true, y_pred)
+
+    assert_allclose(deviations, np.asarray([0.0, 1.0]))
+    assert_allclose(accuracy, np.asarray([0.2, 1.0]))
+
+
+def test_rec_curve_squared_loss():
+    # One residual of zero, one residual of one, and three residuals of two
+    y_true = np.array([-1, 1, 2, -2, 0])
+    y_pred = np.array([1, 2, 2, 0, -2])
+
+    deviations, accuracy = rec_curve(y_true, y_pred, loss="squared")
+
+    assert_allclose(deviations, np.asarray([0.0, 1.0, 4.0]))
+    assert_allclose(accuracy, np.asarray([0.2, 0.4, 1.0]))
diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py
index 946c95186374b..9a01e080a6676 100644
--- a/sklearn/utils/_plotting.py
+++ b/sklearn/utils/_plotting.py
@@ -10,12 +10,8 @@
 from .validation import _check_pos_label_consistency
 
 
-class _BinaryClassifierCurveDisplayMixin:
-    """Mixin class to be used in Displays requiring a binary classifier.
-
-    The aim of this class is to centralize some validations regarding the estimator and
-    the target and gather the response of the estimator.
-    """
+class _CurveDisplayMixin:
+    """Mixin class to be used in Displays plotting a tradeoff curve."""
 
     def _validate_plot_params(self, *, ax=None, name=None):
         check_matplotlib_support(f"{self.__class__.__name__}.plot")
@@ -27,6 +23,14 @@ def _validate_plot_params(self, *, ax=None, name=None):
         name = self.estimator_name if name is None else name
         return ax, ax.figure, name
 
+
+class _BinaryClassifierCurveDisplayMixin(_CurveDisplayMixin):
+    """Mixin class to be used in Displays requiring a binary classifier.
+
+    The aim of this class is to centralize some validations regarding the estimator and
+    the target and gather the response of the estimator.
+    """
+
     @classmethod
     def _validate_and_get_response_values(
         cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None
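Since the REC curve is just the empirical CDF of the regression errors, the output of
`rec_curve` can be cross-checked against a direct CDF computation. The following sketch
is illustrative only and not part of the patch; the helper name `empirical_error_cdf`
is hypothetical. It reproduces the documented behavior for the absolute loss, including
the (0, 0) point that `rec_curve` prepends when no sample has exactly zero error:

import numpy as np

from sklearn.metrics import rec_curve


def empirical_error_cdf(y_true, y_pred):
    # Empirical CDF of |y_true - y_pred|, evaluated at each unique error value.
    errors = np.abs(np.asarray(y_true, dtype=float) - np.asarray(y_pred, dtype=float))
    deviations, counts = np.unique(errors, return_counts=True)  # sorted unique errors
    accuracy = np.cumsum(counts) / errors.size
    # rec_curve prepends (0, 0) when the smallest error is strictly positive.
    if deviations[0] > 0.0:
        deviations = np.concatenate(([0.0], deviations))
        accuracy = np.concatenate(([0.0], accuracy))
    return deviations, accuracy


y_true = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
y_pred = np.array([1.1, 2.2, 2.8, 4.3, 4.8, 6.5])

dev_ref, acc_ref = empirical_error_cdf(y_true, y_pred)
dev, acc = rec_curve(y_true, y_pred, loss="absolute")

np.testing.assert_allclose(dev, dev_ref)
np.testing.assert_allclose(acc, acc_ref)

This is the same equivalence that the hand-computed cases added to test_regression.py
exercise on small arrays.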