From 0a976c55d984e4d35afd6ff0bc2050bbb111d341 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sat, 17 May 2025 20:22:02 +0300 Subject: [PATCH 01/26] Initial version of REC curve --- sklearn/metrics/__init__.py | 1 + sklearn/metrics/_plot/rec_curve.py | 356 ++++++++++++++++++ sklearn/metrics/_regression_characteristic.py | 186 +++++++++ 3 files changed, 543 insertions(+) create mode 100644 sklearn/metrics/_plot/rec_curve.py create mode 100644 sklearn/metrics/_regression_characteristic.py diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index ce86525acc368..a0a2c614f19bc 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -65,6 +65,7 @@ root_mean_squared_error, root_mean_squared_log_error, ) +from ._regression_characteristic import rec_curve from ._scorer import check_scoring, get_scorer, get_scorer_names, make_scorer from .cluster import ( adjusted_mutual_info_score, diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py new file mode 100644 index 0000000000000..35bbd892a62a3 --- /dev/null +++ b/sklearn/metrics/_plot/rec_curve.py @@ -0,0 +1,356 @@ +# Authors: The scikit-learn developers +# SPDX-License-Identifier: BSD-3-Clause + +import numpy as np + +from ...base import is_regressor # To check if estimator is a regressor +from ...metrics import rec_curve +from ...utils._optional_dependencies import check_matplotlib_support +from ...utils.validation import check_is_fitted + + +class RecCurveDisplay: + """Regression Error Characteristic (REC) Curve visualization. + + It is recommended to use :func:`~sklearn.metrics.RecCurveDisplay.from_estimator` + or :func:`~sklearn.metrics.RecCurveDisplay.from_predictions` to create + a visualizer. All parameters are stored as attributes. + + Read more in the :ref:`User Guide `. (Assuming this would be added) + + Parameters + ---------- + deviations : ndarray + Sorted unique error tolerance values (x-coordinates). 
+ accuracy : ndarray + Corresponding accuracy values (y-coordinates). + estimator_name : str, default=None + Name of the estimator. If `None`, then the name will be `"Model"`. + loss : {'absolute', 'squared'}, default='absolute' + The loss function used to compute the REC curve. + constant_predictor_deviations : ndarray, default=None + Deviations for the constant predictor's REC curve. `None` if not plotted + or not computed. + constant_predictor_accuracy : ndarray, default=None + Accuracy for the constant predictor's REC curve. `None` if not plotted + or not computed. + constant_predictor_name : str, default=None + Name of the constant predictor. `None` if not plotted or not computed. + + Attributes + ---------- + line_ : matplotlib Artist + REC curve. + ax_ : matplotlib Axes + Axes with REC curve. + figure_ : matplotlib Figure + Figure containing the curve. + constant_predictor_line_ : matplotlib Artist, default=None + Constant predictor REC curve. Only defined if a constant predictor + was plotted. + + See Also + -------- + rec_curve : Compute Regression Error Characteristic (REC) curve. + RecCurveDisplay.from_estimator : Plot REC curve given an estimator and data. + RecCurveDisplay.from_predictions : Plot REC curve given true and predicted values. + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> import numpy as np + >>> from sklearn.linear_model import LinearRegression + >>> # Assuming rec_curve function is defined and available (mocked above for example) + >>> X = np.array([[1], [2], [3], [4], [5]]) + >>> y = np.array([1, 2.5, 3, 4.5, 5]) + >>> estimator = LinearRegression().fit(X, y) + >>> display = RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute') + >>> # display.plot() # To show plot in interactive environment + >>> # plt.show() + >>> y_pred = estimator.predict(X) + >>> display_pred = RecCurveDisplay.from_predictions( + ... y, y_pred, loss='squared', name="My Model", plot_constant_predictor=False + ... 
) + >>> # display_pred.plot() + >>> # plt.show() + """ + + def __init__( + self, + *, + deviations, + accuracy, + estimator_name=None, + max_const_error=None, + constant_predictor_deviations=None, + constant_predictor_accuracy=None, + constant_predictor_name=None, + ): + self.deviations = deviations + self.accuracy = accuracy + self.estimator_name = estimator_name if estimator_name is not None else "Model" + + self.constant_predictor_deviations = constant_predictor_deviations + self.constant_predictor_accuracy = constant_predictor_accuracy + self.constant_predictor_name = constant_predictor_name + + def plot( + self, + ax=None, + *, + name=None, + plot_const_predictor=True, + clip_max_const_error=True, + **kwargs, + ): + """Plot visualization. + + Parameters + ---------- + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + name : str, default=None + Name of REC curve for labeling. If `None`, use the name stored in + `estimator_name`. + **kwargs : dict + Keyword arguments to be passed to `matplotlib.pyplot.plot` for the + main REC curve. + + Returns + ------- + display : :class:`~sklearn.metrics.RecCurveDisplay` + Object that stores computed values. + """ + + check_matplotlib_support(f"{self.__class__.__name__}.plot") + import matplotlib.pyplot as plt + + if ax is None: + self.figure_, self.ax_ = plt.subplots() + else: + self.ax_ = ax + self.figure_ = self.ax_.figure + + plot_name = name if name is not None else self.estimator_name + line_kwargs = {} + if "label" not in kwargs: # Allow user to override label + line_kwargs["label"] = plot_name + line_kwargs.update(kwargs) + + if self.constant_predictor_deviations: + max_const_error = max(self.constant_predictor_deviations) + elif clip_max_const_error: + raise ValueError( + "clip_max_const_error is True, but no constant deviations were provided." 
+ ) + + if clip_max_const_error: + mask = self.deviations <= max_const_error + self.line_, _ = self.ax_.plot( + self.deviations[mask], self.accuracy[mask], **line_kwargs + ) + else: + self.line_, _ = self.ax_.plot(self.deviations, self.accuracy, **line_kwargs) + + # Plot constant predictor if its data exists and the flag is set + if ( + plot_const_predictor + and self.constant_predictor_deviations is not None + and self.constant_predictor_accuracy is not None + ): + cp_name = ( + self.constant_predictor_name + if self.constant_predictor_name + else "Constant Predictor" + ) + # Default style for constant predictor, can be overridden if needed + cp_kwargs = {"label": cp_name, "linestyle": "--"} + self.constant_predictor_line_, _ = self.ax_.plot( + self.constant_predictor_deviations, + self.constant_predictor_accuracy, + **cp_kwargs, + ) + + self.ax_.set_xlabel(f"Error Tolerance (Deviation - {self.loss} loss)") + self.ax_.set_ylabel("Accuracy (Fraction of samples)") + self.ax_.set_title("Regression Error Characteristic (REC) Curve") + self.ax_.legend(loc="lower right") + self.ax_.grid(True) + + return self + + def from_estimator( + cls, + estimator, + X, + y, + *, + loss="absolute", + constant_predictor=None, + plot_constant_predictor=True, + clip_max_const_error=True, + name=None, + ax=None, + **kwargs, + ): + """Create a REC Curve display from an estimator. + + Parameters + ---------- + estimator : object + Fitted estimator or a fitted :class:`~sklearn.pipeline.Pipeline` + in which the last estimator is a regressor. + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Input values. + y : array-like of shape (n_samples,) + Target values. + sample_weight : array-like of shape (n_samples,), default=None + Sample weights. Currently not supported by the underlying `rec_curve` + function and will raise an error if provided. + loss : {'absolute', 'squared'}, default='absolute' + The loss function to use for calculating deviations. 
+ constant_predictor : {'mean', 'median', None}, default=None + The type of constant predictor to plot as a baseline. + If 'mean', uses the mean of `y_true`. + If 'median', uses the median of `y_true`. + If `None`, chooses 'mean' for 'squared' loss and 'median' for + 'absolute' loss, as these are the optimal constant predictors + for these losses. + plot_constant_predictor : bool, default=True + Whether to compute and plot the REC curve for the constant predictor. + clip_max_const_error : bool, default=True + If `True`, the x-axis (error tolerance) will be cut off at the + maximum error achieved by the constant predictor. This is only + effective if a constant predictor is computed. + name : str, default=None + Name for the REC curve. If `None`, the estimator's class name will be used. + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + **kwargs : dict + Keyword arguments to be passed to `matplotlib.pyplot.plot` for the + estimator's REC curve. + + Returns + ------- + display : :class:`~sklearn.metrics.RecCurveDisplay` + Object that stores computed values. + """ + check_is_fitted(estimator) + if not is_regressor(estimator): + raise TypeError(f"{estimator.__class__.__name__} is not a regressor.") + + y_pred = estimator.predict(X) + + if name is None: + name = estimator.__class__.__name__ + + # Call from_predictions with validated parameters + return cls.from_predictions( + y_true=y, + y_pred=y_pred, + loss=loss, + constant_predictor=constant_predictor, + plot_constant_predictor=plot_constant_predictor, + clip_max_const_error=clip_max_const_error, + name=name, + ax=ax, + **kwargs, + ) + + def from_predictions( + cls, + y_true, + y_pred, + *, + loss="absolute", + constant_predictor=None, + plot_constant_predictor=True, + clip_max_const_error=True, + name=None, + ax=None, + **kwargs, + ): + """Plot REC curve given true and predicted values. 
+ + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True target values. + y_pred : array-like of shape (n_samples,) or scalar + Estimated target values. The `rec_curve` function handles scalar `y_pred`. + loss : {'absolute', 'squared'}, default='absolute' + The loss function to use for calculating deviations. + constant_predictor : {'mean', 'median', None}, default=None + The type of constant predictor to plot as a baseline. + If 'mean', uses the mean of `y_true`. + If 'median', uses the median of `y_true`. + If `None`, chooses 'mean' for 'squared' loss and 'median' for + 'absolute' loss. + plot_constant_predictor : bool, default=True + Whether to compute and plot the REC curve for the constant predictor. + clip_max_const_error : bool, default=True + If `True`, the x-axis (error tolerance) will be cut off at the + maximum error achieved by the constant predictor. This is only + effective if a constant predictor is computed. + name : str, default=None + Name for the REC curve. If `None`, will be "Model". + ax : matplotlib axes, default=None + Axes object to plot on. If `None`, a new figure and axes is created. + **kwargs : dict + Keyword arguments to be passed to `matplotlib.pyplot.plot` for the + main REC curve. + + Returns + ------- + display : :class:`~sklearn.metrics.RecCurveDisplay` + Object that stores computed values. + """ + main_devs, main_acc = rec_curve(y_true, y_pred, loss=loss) + + # Convert to NumPy arrays for consistent handling in display class, + # as rec_curve might accept and return array_api specific arrays. 
+ y_true_np = np.asarray(y_true) + main_devs_np = np.asarray(main_devs) + main_acc_np = np.asarray(main_acc) + + # Determine the constant predictor type if not specified + actual_constant_predictor_type = constant_predictor + if actual_constant_predictor_type is None: + if loss == "squared": + actual_constant_predictor_type = "mean" + else: # loss == "absolute" + actual_constant_predictor_type = "median" + + # Compute constant predictor data if needed for plotting or cutoff + if plot_constant_predictor or clip_max_const_error: + if actual_constant_predictor_type == "mean": + constant_value = np.mean(y_true_np) + cp_name_val = "Mean Predictor" + elif actual_constant_predictor_type == "median": + constant_value = np.median(y_true_np) + cp_name_val = "Median Predictor" + # No else needed here as validate_params covers constant_predictor values + + cp_devs, cp_accs = rec_curve(y_true, constant_value, loss=loss) + cp_devs_np = np.asarray(cp_devs) + cp_accs_np = np.asarray(cp_accs) + else: + cp_devs_np, cp_accs_np, cp_name_val = None, None, None + + display_name = name if name is not None else "Model" + + obj = RecCurveDisplay( + deviations=main_devs_np, + accuracy=main_acc_np, + estimator_name=display_name, + loss=loss, # loss is already validated + constant_predictor_deviations=cp_devs_np, + constant_predictor_accuracy=cp_accs_np, + constant_predictor_name=cp_name_val, + ) + return obj.plot( + ax=ax, + plot_constant_predictor=plot_constant_predictor, + clip_max_const_error=clip_max_const_error, + **kwargs, + ) diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py new file mode 100644 index 0000000000000..f0c9658a3c71a --- /dev/null +++ b/sklearn/metrics/_regression_characteristic.py @@ -0,0 +1,186 @@ +"""Regression Error Characteristic curve""" + +# Authors: The scikit-learn developers +# SPDX-License-Identifier: BSD-3-Clause + +import numbers # For type checking Python scalars + +from ..utils import check_array, 
check_consistent_length +from ..utils._array_api import get_namespace_and_device # For array_api support +from ..utils._param_validation import StrOptions, validate_params + + +@validate_params( + { + "y_true": ["array-like"], + "y_pred": ["array-like", numbers.Real], # Allow scalar or array-like + "loss": [StrOptions({"absolute", "squared"})], + }, + prefer_skip_nested_validation=True, # Standard practice for functions doing further checks +) +def rec_curve(y_true, y_pred, *, loss="absolute"): + """Compute Regression Error Characteristic (REC) curve. + + The REC curve evaluates regression models by plotting the error tolerance + (deviation) on the x-axis against the percentage of data points predicted + within that tolerance (accuracy) on the y-axis. It is the empirical + Cumulative Distribution Function (CDF) of the error. + + This implementation is designed to be compatible with the array_api + standard and scikit-learn's utilities. + + Read more in the :ref:`User Guide `. (Assuming this would be added) + + Parameters + ---------- + y_true : array-like of shape (n_samples,) + True target values. + + y_pred : array-like of shape (n_samples,) or scalar + Estimated target values. If a scalar is provided, it is treated as + a constant prediction for all samples. + + loss : {'absolute', 'squared'}, default='absolute' + The type of loss to use for calculating deviations. + - 'absolute': Uses absolute deviations |y_true - y_pred|. + - 'squared': Uses squared deviations (y_true - y_pred)^2. + + Returns + ------- + deviations : ndarray + Sorted unique error tolerance values. These are the x-coordinates + for the REC curve. The array will start with 0.0 if the smallest + calculated error is greater than 0, representing zero tolerance. + + accuracy : ndarray + The corresponding accuracy (fraction of samples with error less than + or equal to the deviation). These are the y-coordinates for the REC + curve. 
The array will start with 0.0 if `deviations` starts with an + explicit 0.0. + + See Also + -------- + roc_curve : Compute Receiver Operating Characteristic (ROC) curve. + det_curve : Compute Detection Error Tradeoff (DET) curve. + + References + ---------- + .. [1] Bi, J., & Bennett, K. P. (2003). Regression error characteristic + curves. In Proceedings of the 20th International Conference on + Machine Learning (ICML-03) (pp. 43-50). + + Examples + -------- + >>> import numpy as np + >>> from sklearn.metrics import rec_curve # Assuming function is in sklearn.metrics + >>> y_true = np.array([1, 2, 3, 4, 5, 6]) + >>> y_pred_model1 = np.array([1.1, 2.2, 2.8, 4.3, 4.8, 6.5]) + >>> deviations, accuracy = rec_curve(y_true, y_pred_model1, loss='absolute') + >>> deviations + array([0. , 0.1 , 0.2 , 0.3 , 0.5 ]) + >>> accuracy + array([0. , 0.16666667, 0.66666667, 0.83333333, 1. ]) + + >>> # Example with a scalar prediction (constant model) + >>> y_pred_scalar = 3.5 + >>> dev_scalar, acc_scalar = rec_curve(y_true, y_pred_scalar) + >>> dev_scalar + array([0. , 0.5, 1.5, 2.5]) + >>> acc_scalar + array([0. , 0.33333333, 0.66666667, 1. ]) + + >>> # Example with squared loss + >>> dev_sq, acc_sq = rec_curve(y_true, y_pred_model1, loss='squared') + >>> dev_sq # These are squared errors + array([0. , 0.01 , 0.04 , 0.09 , 0.25 ]) + >>> acc_sq + array([0. , 0.16666667, 0.66666667, 0.83333333, 1. 
]) + + >>> # For plotting with matplotlib: + >>> # import matplotlib.pyplot as plt + >>> # plt.figure() + >>> # plt.plot(deviations, accuracy, marker='.', label='Model 1 (Absolute Loss)') + >>> # plt.plot(dev_scalar, acc_scalar, marker='.', label='Constant Model (Absolute Loss)') + >>> # plt.xlabel("Error Tolerance (Deviation)") + >>> # plt.ylabel("Accuracy (Fraction of samples within tolerance)") + >>> # plt.title("Regression Error Characteristic (REC) Curve") + >>> # plt.legend() + >>> # plt.grid(True) + >>> # plt.show() + """ + # Validate y_true and get the array namespace (xp) + y_true_array = check_array( + y_true, ensure_2d=False, dtype="numeric", ensure_all_finite=True + ) + xp, _, device = get_namespace_and_device(y_true_array) + + # Handle y_pred: check if it's a scalar or array-like + # Python native scalars (int, float) + if isinstance(y_pred, numbers.Number): # numbers.Real covers int, float + y_pred_scalar_val = float(y_pred) + y_pred = xp.full( + y_true_array.shape, + fill_value=y_pred_scalar_val, + dtype=y_true_array.dtype, # Match y_true's dtype for consistency + device=device, + ) + y_pred_array = check_array( + y_pred, ensure_2d=False, dtype="numeric", ensure_all_finite=True + ) + check_consistent_length(y_true_array, y_pred_array) + + # Validate loss parameter + if loss not in ("absolute", "squared"): + raise ValueError( + f"loss type '{loss}' not supported, choose 'absolute' or 'squared'." + ) + + # Calculate deviations based on the chosen loss + # Since y_true_array and y_pred_array are finite, differences and errors will be finite. 
+ differences = y_true_array - y_pred_array + if loss == "absolute": + errors = xp.abs(differences) + else: # loss == "squared" + errors = xp.square(differences) + + n_samples = y_true_array.shape[0] + + # Handle empty input (no samples) + if n_samples == 0: + empty_float_array = xp.asarray( + [], dtype=xp.float64, device=xp.device(y_true_array) + ) + return empty_float_array, empty_float_array + + # OPTIMIZED CDF CALCULATION: + # Get unique sorted error values (deviations_calc) and their counts. + # xp.unique_counts returns sorted unique values. + # Since errors are finite, deviations_calc will also be finite and non-empty if n_samples > 0. + deviations_calc, counts = xp.unique_counts(errors) + + # Calculate cumulative accuracy + cumulative_counts = xp.cumsum(counts) + # Ensure accuracy_values is float64 for consistency and precision. + accuracy_values = xp.astype(cumulative_counts, xp.float64) / float(n_samples) + + # Prepare output deviations and accuracy + # Prepend (0,0) if the smallest error (first element of deviations_calc) is > 0.0, + # ensuring the curve starts from the origin of the plot unless + # there are samples with exactly zero error. + # deviations_calc[0] is safe to access as n_samples > 0 implies deviations_calc is non-empty. + if deviations_calc[0] > 0.0: + # Create zero point with the correct dtype and device + # deviations_calc.dtype could be float32 or float64 depending on input error calculation. + zero_dev = xp.asarray([0.0], dtype=deviations_calc.dtype, device=device) + # accuracy_values is already float64. 
+ zero_acc = xp.asarray([0.0], dtype=accuracy_values.dtype, device=device) + + deviations_out = xp.concatenate((zero_dev, deviations_calc)) + accuracy_out = xp.concatenate((zero_acc, accuracy_values)) + else: + # Smallest error is 0.0 (or less, though errors should be non-negative and finite) + # The curve naturally starts at (0, accuracy_for_zero_error) + deviations_out = deviations_calc + accuracy_out = accuracy_values + + return deviations_out, accuracy_out From 643339f371d41b6e01c5fe7e18d430bb0a3755d6 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sat, 17 May 2025 20:37:25 +0300 Subject: [PATCH 02/26] Added import to RecCurveDisplay --- sklearn/metrics/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sklearn/metrics/__init__.py b/sklearn/metrics/__init__.py index a0a2c614f19bc..b8281de7ae038 100644 --- a/sklearn/metrics/__init__.py +++ b/sklearn/metrics/__init__.py @@ -30,6 +30,7 @@ from ._plot.confusion_matrix import ConfusionMatrixDisplay from ._plot.det_curve import DetCurveDisplay from ._plot.precision_recall_curve import PrecisionRecallDisplay +from ._plot.rec_curve import RecCurveDisplay from ._plot.regression import PredictionErrorDisplay from ._plot.roc_curve import RocCurveDisplay from ._ranking import ( @@ -101,6 +102,7 @@ "DistanceMetric", "PrecisionRecallDisplay", "PredictionErrorDisplay", + "RecCurveDisplay", "RocCurveDisplay", "accuracy_score", "adjusted_mutual_info_score", @@ -169,6 +171,7 @@ "precision_score", "r2_score", "rand_score", + "rec_curve", "recall_score", "roc_auc_score", "roc_curve", From 87e19ed051cbfd28afc13c054ab37698d2bede10 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sat, 17 May 2025 20:57:44 +0300 Subject: [PATCH 03/26] Added tests and fixed bugs. 
--- sklearn/metrics/_plot/rec_curve.py | 2 +- sklearn/metrics/_regression_characteristic.py | 7 +++- sklearn/metrics/tests/test_regression.py | 33 +++++++++++++++++++ 3 files changed, 40 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 35bbd892a62a3..57799d431a5cd 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -4,7 +4,7 @@ import numpy as np from ...base import is_regressor # To check if estimator is a regressor -from ...metrics import rec_curve +from ...metrics._regression_characteristic import rec_curve from ...utils._optional_dependencies import check_matplotlib_support from ...utils.validation import check_is_fitted diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py index f0c9658a3c71a..f0fbc5674835b 100644 --- a/sklearn/metrics/_regression_characteristic.py +++ b/sklearn/metrics/_regression_characteristic.py @@ -114,7 +114,6 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): ) xp, _, device = get_namespace_and_device(y_true_array) - # Handle y_pred: check if it's a scalar or array-like # Python native scalars (int, float) if isinstance(y_pred, numbers.Number): # numbers.Real covers int, float y_pred_scalar_val = float(y_pred) @@ -124,6 +123,12 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): dtype=y_true_array.dtype, # Match y_true's dtype for consistency device=device, ) + + # array-like with a single prediction + if y_pred.size == 1: + y_pred = xp.squeeze(y_pred) + y_pred = xp.tile(y_pred, y_true_array.shape) + y_pred_array = check_array( y_pred, ensure_2d=False, dtype="numeric", ensure_all_finite=True ) diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py index 5e90727583189..efe10c707c142 100644 --- a/sklearn/metrics/tests/test_regression.py +++ b/sklearn/metrics/tests/test_regression.py @@ -23,6 +23,7 @@ mean_tweedie_deviance, 
median_absolute_error, r2_score, + rec_curve, root_mean_squared_error, root_mean_squared_log_error, ) @@ -632,3 +633,35 @@ def test_pinball_loss_relation_with_mae(global_random_seed): mean_absolute_error(y_true, y_pred) == mean_pinball_loss(y_true, y_pred, alpha=0.5) * 2 ) + + +@pytest.mark.parametrize("constant_one_pred", [1.0, np.asarray(1.0), np.asarray([1.0])]) +def test_rec_curve_const_pred(constant_one_pred): + y_true = np.array([-1, 1, 2, -2, 0]) + + deviations, accuracy = rec_curve(y_true, constant_one_pred) + + assert_allclose(deviations, np.asarray([0.0, 1.0, 2.0, 3.0])) + assert_allclose(accuracy, np.asarray([0.2, 0.6, 0.8, 1.0])) + + +def test_rec_curve_array_pred(): + # four residuals of 1, and one residual of 0 + y_true = np.array([-1, 1, 2, -2, 0]) + y_pred = np.array([0, 2, 3, -1, 0]) + + deviations, accuracy = rec_curve(y_true, y_pred) + + assert_allclose(deviations, np.asarray([0.0, 1.0])) + assert_allclose(accuracy, np.asanyarray([0.2, 1.0])) + + +def test_rec_curve_squared_loss(): + # one residual of one, one residual of zero, three residuals of 2 + y_true = np.array([-1, 1, 2, -2, 0]) + y_pred = np.array([1, 2, 2, 0, -2]) + + deviations, accuracy = rec_curve(y_true, y_pred, loss="squared") + + assert_allclose(deviations, np.asarray([0.0, 1.0, 4.0])) + assert_allclose(accuracy, np.asanyarray([0.2, 0.4, 1.0])) From fbcd7f86e89c6d80e0fff43a74f5bece17530f40 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 08:48:43 +0300 Subject: [PATCH 04/26] Added examples of REC curve plots. 
--- .../plot_rec_curve_visualization.py | 85 +++++++++++++++++++ sklearn/metrics/_plot/rec_curve.py | 30 ++++--- 2 files changed, 103 insertions(+), 12 deletions(-) create mode 100644 examples/miscellaneous/plot_rec_curve_visualization.py diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py new file mode 100644 index 0000000000000..f4bf3a2721467 --- /dev/null +++ b/examples/miscellaneous/plot_rec_curve_visualization.py @@ -0,0 +1,85 @@ +""" +============================================= +Regression Error Characteristic curve display +============================================= + +We illustrate how to display the regression error characteristic curve, and how to use it +to compare various regression models. +""" + +# Authors: The scikit-learn developers +# SPDX-License-Identifier: BSD-3-Clause + +# %% +# Data Loading and Preparation +# ---------------------------- +# +# Load the diabetes dataset. For simplicity, we only keep a single feature in the data. +# Then, we split the data and target into training and test sets. +from sklearn.datasets import load_diabetes +from sklearn.model_selection import train_test_split + +X, y = load_diabetes(return_X_y=True) +X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False) + +# %% +# Linear regression model +# ----------------------- +# +# We create a linear regression model and fit it on the training after standard scaling +from sklearn.linear_model import LinearRegression +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + +lr_estimator = make_pipeline(StandardScaler(), LinearRegression()) +lr_estimator.fit(X_train, y_train) + +# %% +# Display REC curve using the test set. See that our linear regressor's REC curve dominates +# the curve of the default constant predictor - the median. This is an indicator that our model is doing something +# "reasonable". 
+from sklearn.metrics import RecCurveDisplay + +RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test) + + +# %% +# Display REC curve of the model only, without the constant predictor. The linear regressor's REC curve also +# dominates the one of another constant predictor - the mean. +RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test, constant_predictor="mean") + + +# %% Compare two REC curves of linear regression vs linear SVR. We can see different performance profiles +# of both estimators. The LinearSVR appears to have more samples with errors below 60, whereas +# linear regressor appears to have more samples than the SVR with errors below 20. This is despite both having +# almost the same summary metrics. +import matplotlib.pyplot as plt +from sklearn.svm import LinearSVR +from sklearn.metrics import root_mean_squared_error, mean_absolute_error + + +svr_estimator = make_pipeline(StandardScaler(), LinearSVR()) +svr_estimator.fit(X_train, y_train) + +pred_lr = lr_estimator.predict(X_test) +pred_svr = svr_estimator.predict(X_test) + +lr_metrics = f"RMSE = {root_mean_squared_error(pred_lr, y_test):.2f}, MAE = {mean_absolute_error(pred_lr, y_test):.2f}" +svr_metrics = f"RMSE = {root_mean_squared_error(pred_svr, y_test):.2f}, MAE = {mean_absolute_error(pred_svr, y_test):.2f}" + +fig, ax = plt.subplots() +RecCurveDisplay.from_predictions( + pred_lr, + y_test, + ax=ax, + name=f"Linear regression ({lr_metrics})", + plot_const_predictor=False, +) +RecCurveDisplay.from_predictions( + pred_svr, + y_test, + ax=ax, + name=f"Linear SVR ({svr_metrics})", + plot_const_predictor=False, +) +fig.show() diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 57799d431a5cd..a99f328ae566c 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -69,7 +69,7 @@ class RecCurveDisplay: >>> # plt.show() >>> y_pred = estimator.predict(X) >>> display_pred = RecCurveDisplay.from_predictions( - ... 
y, y_pred, loss='squared', name="My Model", plot_constant_predictor=False + ... y, y_pred, loss='squared', name="My Model", plot_const_predictor=False ... ) >>> # display_pred.plot() >>> # plt.show() @@ -81,6 +81,7 @@ def __init__( deviations, accuracy, estimator_name=None, + loss=None, max_const_error=None, constant_predictor_deviations=None, constant_predictor_accuracy=None, @@ -89,6 +90,7 @@ def __init__( self.deviations = deviations self.accuracy = accuracy self.estimator_name = estimator_name if estimator_name is not None else "Model" + self.loss = loss self.constant_predictor_deviations = constant_predictor_deviations self.constant_predictor_accuracy = constant_predictor_accuracy @@ -137,7 +139,7 @@ def plot( line_kwargs["label"] = plot_name line_kwargs.update(kwargs) - if self.constant_predictor_deviations: + if self.constant_predictor_deviations is not None: max_const_error = max(self.constant_predictor_deviations) elif clip_max_const_error: raise ValueError( @@ -146,11 +148,13 @@ def plot( if clip_max_const_error: mask = self.deviations <= max_const_error - self.line_, _ = self.ax_.plot( + self.line_, *_ = self.ax_.plot( self.deviations[mask], self.accuracy[mask], **line_kwargs ) else: - self.line_, _ = self.ax_.plot(self.deviations, self.accuracy, **line_kwargs) + self.line_, *_ = self.ax_.plot( + self.deviations, self.accuracy, **line_kwargs + ) # Plot constant predictor if its data exists and the flag is set if ( @@ -165,7 +169,7 @@ def plot( ) # Default style for constant predictor, can be overridden if needed cp_kwargs = {"label": cp_name, "linestyle": "--"} - self.constant_predictor_line_, _ = self.ax_.plot( + self.constant_predictor_line_, *_ = self.ax_.plot( self.constant_predictor_deviations, self.constant_predictor_accuracy, **cp_kwargs, @@ -179,6 +183,7 @@ def plot( return self + @classmethod def from_estimator( cls, estimator, @@ -187,7 +192,7 @@ def from_estimator( *, loss="absolute", constant_predictor=None, - plot_constant_predictor=True, + 
plot_const_predictor=True, clip_max_const_error=True, name=None, ax=None, @@ -216,7 +221,7 @@ def from_estimator( If `None`, chooses 'mean' for 'squared' loss and 'median' for 'absolute' loss, as these are the optimal constant predictors for these losses. - plot_constant_predictor : bool, default=True + plot_const_predictor : bool, default=True Whether to compute and plot the REC curve for the constant predictor. clip_max_const_error : bool, default=True If `True`, the x-axis (error tolerance) will be cut off at the @@ -250,13 +255,14 @@ def from_estimator( y_pred=y_pred, loss=loss, constant_predictor=constant_predictor, - plot_constant_predictor=plot_constant_predictor, + plot_const_predictor=plot_const_predictor, clip_max_const_error=clip_max_const_error, name=name, ax=ax, **kwargs, ) + @classmethod def from_predictions( cls, y_true, @@ -264,7 +270,7 @@ def from_predictions( *, loss="absolute", constant_predictor=None, - plot_constant_predictor=True, + plot_const_predictor=True, clip_max_const_error=True, name=None, ax=None, @@ -286,7 +292,7 @@ def from_predictions( If 'median', uses the median of `y_true`. If `None`, chooses 'mean' for 'squared' loss and 'median' for 'absolute' loss. - plot_constant_predictor : bool, default=True + plot_const_predictor : bool, default=True Whether to compute and plot the REC curve for the constant predictor. 
clip_max_const_error : bool, default=True If `True`, the x-axis (error tolerance) will be cut off at the @@ -322,7 +328,7 @@ def from_predictions( actual_constant_predictor_type = "median" # Compute constant predictor data if needed for plotting or cutoff - if plot_constant_predictor or clip_max_const_error: + if plot_const_predictor or clip_max_const_error: if actual_constant_predictor_type == "mean": constant_value = np.mean(y_true_np) cp_name_val = "Mean Predictor" @@ -350,7 +356,7 @@ def from_predictions( ) return obj.plot( ax=ax, - plot_constant_predictor=plot_constant_predictor, + plot_const_predictor=plot_const_predictor, clip_max_const_error=clip_max_const_error, **kwargs, ) From 64e101be32b6cf15157b60a836d4e31ce15a8d69 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 09:41:52 +0300 Subject: [PATCH 05/26] Fixed REC curve example --- .../miscellaneous/plot_rec_curve_visualization.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py index f4bf3a2721467..64e4c4eac5032 100644 --- a/examples/miscellaneous/plot_rec_curve_visualization.py +++ b/examples/miscellaneous/plot_rec_curve_visualization.py @@ -40,19 +40,14 @@ # "reasonable". from sklearn.metrics import RecCurveDisplay -RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test) - - -# %% -# Display REC curve of the model only, without the constant predictor. The linear regressor's REC curve also -# dominates the one of another constant predictor - the mean. -RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test, constant_predictor="mean") +RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test, name="Linear regression") # %% Compare two REC curves of linear regression vs linear SVR. We can see different performance profiles # of both estimators. 
The LinearSVR appears to have more samples with errors below 60, whereas # linear regressor appears to have more samples than the SVR with errors below 20. This is despite both having # almost the same summary metrics. +# NOTE - this is a toy example. To draw conclusions, we will need a larger test set. import matplotlib.pyplot as plt from sklearn.svm import LinearSVR from sklearn.metrics import root_mean_squared_error, mean_absolute_error @@ -69,15 +64,15 @@ fig, ax = plt.subplots() RecCurveDisplay.from_predictions( - pred_lr, y_test, + pred_lr, ax=ax, name=f"Linear regression ({lr_metrics})", plot_const_predictor=False, ) RecCurveDisplay.from_predictions( - pred_svr, y_test, + pred_svr, ax=ax, name=f"Linear SVR ({svr_metrics})", plot_const_predictor=False, From 7aebfca77a0479df928fe528a2b1f750fd680cf9 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 10:02:11 +0300 Subject: [PATCH 06/26] Improved the REC curve example --- .../plot_rec_curve_visualization.py | 32 +++++++++++-------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py index 64e4c4eac5032..75e4ab395c418 100644 --- a/examples/miscellaneous/plot_rec_curve_visualization.py +++ b/examples/miscellaneous/plot_rec_curve_visualization.py @@ -20,7 +20,9 @@ from sklearn.model_selection import train_test_split X, y = load_diabetes(return_X_y=True) -X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=20, shuffle=False) +X_train, X_test, y_train, y_test = train_test_split( + X, y, test_size=80, random_state=42, shuffle=True +) # %% # Linear regression model @@ -43,24 +45,26 @@ RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test, name="Linear regression") -# %% Compare two REC curves of linear regression vs linear SVR. We can see different performance profiles -# of both estimators. 
The LinearSVR appears to have more samples with errors below 60, whereas -# linear regressor appears to have more samples than the SVR with errors below 20. This is despite both having -# almost the same summary metrics. -# NOTE - this is a toy example. To draw conclusions, we will need a larger test set. +# %% +# Compare two REC curves of linear regression vs ridge regression for polnyomial features. We can see that the curve +# of the ridge regressor with polynomial features dominates the one of the linear regressor. Meaning, for any error +# tolerance, the Poly-Ridge model has more samples below this tolerance. import matplotlib.pyplot as plt -from sklearn.svm import LinearSVR +from sklearn.linear_model import RidgeCV from sklearn.metrics import root_mean_squared_error, mean_absolute_error +from sklearn.preprocessing import PolynomialFeatures -svr_estimator = make_pipeline(StandardScaler(), LinearSVR()) -svr_estimator.fit(X_train, y_train) +ridge_poly_estimator = make_pipeline( + StandardScaler(), PolynomialFeatures(2, include_bias=False), RidgeCV() +) +ridge_poly_estimator.fit(X_train, y_train) pred_lr = lr_estimator.predict(X_test) -pred_svr = svr_estimator.predict(X_test) +pred_ridge_poly = ridge_poly_estimator.predict(X_test) lr_metrics = f"RMSE = {root_mean_squared_error(pred_lr, y_test):.2f}, MAE = {mean_absolute_error(pred_lr, y_test):.2f}" -svr_metrics = f"RMSE = {root_mean_squared_error(pred_svr, y_test):.2f}, MAE = {mean_absolute_error(pred_svr, y_test):.2f}" +ridge_poly_metrics = f"RMSE = {root_mean_squared_error(pred_ridge_poly, y_test):.2f}, MAE = {mean_absolute_error(pred_ridge_poly, y_test):.2f}" fig, ax = plt.subplots() RecCurveDisplay.from_predictions( @@ -72,9 +76,11 @@ ) RecCurveDisplay.from_predictions( y_test, - pred_svr, + pred_ridge_poly, ax=ax, - name=f"Linear SVR ({svr_metrics})", + name=f"Ridge Poly ({ridge_poly_metrics})", plot_const_predictor=False, ) fig.show() + +# %% From 10ec0e41fe14439029323cf107cd92710ab37120 Mon Sep 17 00:00:00 
2001 From: Alexander Shtoff Date: Sun, 18 May 2025 14:43:26 +0300 Subject: [PATCH 07/26] Demonstrate using california housing --- .../plot_rec_curve_visualization.py | 35 ++++++++++++------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py index 75e4ab395c418..1459d98714605 100644 --- a/examples/miscellaneous/plot_rec_curve_visualization.py +++ b/examples/miscellaneous/plot_rec_curve_visualization.py @@ -16,12 +16,12 @@ # # Load the diabetes dataset. For simplicity, we only keep a single feature in the data. # Then, we split the data and target into training and test sets. -from sklearn.datasets import load_diabetes +from sklearn.datasets import fetch_california_housing from sklearn.model_selection import train_test_split -X, y = load_diabetes(return_X_y=True) +X, y = fetch_california_housing(return_X_y=True, as_frame=True) X_train, X_test, y_train, y_test = train_test_split( - X, y, test_size=80, random_state=42, shuffle=True + X, y, test_size=0.2, random_state=42, shuffle=True ) # %% @@ -49,22 +49,25 @@ # Compare two REC curves of linear regression vs ridge regression for polnyomial features. We can see that the curve # of the ridge regressor with polynomial features dominates the one of the linear regressor. Meaning, for any error # tolerance, the Poly-Ridge model has more samples below this tolerance. 
+ import matplotlib.pyplot as plt -from sklearn.linear_model import RidgeCV from sklearn.metrics import root_mean_squared_error, mean_absolute_error -from sklearn.preprocessing import PolynomialFeatures +from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.neighbors import KNeighborsRegressor +hgbr_estimator = HistGradientBoostingRegressor() +hgbr_estimator.fit(X_train, y_train) -ridge_poly_estimator = make_pipeline( - StandardScaler(), PolynomialFeatures(2, include_bias=False), RidgeCV() -) -ridge_poly_estimator.fit(X_train, y_train) +knn_estimator = make_pipeline(StandardScaler(), KNeighborsRegressor()) +knn_estimator.fit(X_train, y_train) pred_lr = lr_estimator.predict(X_test) -pred_ridge_poly = ridge_poly_estimator.predict(X_test) +pred_hgbr = hgbr_estimator.predict(X_test) +pred_knn = knn_estimator.predict(X_test) lr_metrics = f"RMSE = {root_mean_squared_error(pred_lr, y_test):.2f}, MAE = {mean_absolute_error(pred_lr, y_test):.2f}" -ridge_poly_metrics = f"RMSE = {root_mean_squared_error(pred_ridge_poly, y_test):.2f}, MAE = {mean_absolute_error(pred_ridge_poly, y_test):.2f}" +hgbr_metrics = f"RMSE = {root_mean_squared_error(pred_hgbr, y_test):.2f}, MAE = {mean_absolute_error(pred_hgbr, y_test):.2f}" +knn_metrics = f"RMSE = {root_mean_squared_error(pred_knn, y_test):.2f}, MAE = {mean_absolute_error(pred_knn, y_test):.2f}" fig, ax = plt.subplots() RecCurveDisplay.from_predictions( @@ -76,11 +79,17 @@ ) RecCurveDisplay.from_predictions( y_test, - pred_ridge_poly, + pred_knn, ax=ax, - name=f"Ridge Poly ({ridge_poly_metrics})", + name=f"KNN ({knn_metrics})", plot_const_predictor=False, ) +RecCurveDisplay.from_predictions( + y_test, + pred_hgbr, + ax=ax, + name=f"Gradient Boosting Regressor ({hgbr_metrics})", +) fig.show() # %% From f978b992dc19828d378030bbfe8160a8d2c9d6ad Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 14:45:02 +0300 Subject: [PATCH 08/26] Added clarification to the comparison plot --- 
examples/miscellaneous/plot_rec_curve_visualization.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py index 1459d98714605..4fcc9ef8856ff 100644 --- a/examples/miscellaneous/plot_rec_curve_visualization.py +++ b/examples/miscellaneous/plot_rec_curve_visualization.py @@ -46,10 +46,9 @@ # %% -# Compare two REC curves of linear regression vs ridge regression for polnyomial features. We can see that the curve -# of the ridge regressor with polynomial features dominates the one of the linear regressor. Meaning, for any error -# tolerance, the Poly-Ridge model has more samples below this tolerance. - +# Compare two REC curves of linear regression to KNN and histogram gradient boosting regressors. We can see a clear +# hierarchy here. KNN is strictly better than linear regression for any error tolerance, and HGBR is strictly better +# than both for any error tolerance. 
import matplotlib.pyplot as plt from sklearn.metrics import root_mean_squared_error, mean_absolute_error from sklearn.ensemble import HistGradientBoostingRegressor From 9c9f804fbcd0d2f8cee732477dd91150c477eb46 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 15:22:47 +0300 Subject: [PATCH 09/26] Fixed documentation --- .../plot_rec_curve_visualization.py | 35 ++++++++++++------- sklearn/metrics/_regression_characteristic.py | 28 +++------------ 2 files changed, 27 insertions(+), 36 deletions(-) diff --git a/examples/miscellaneous/plot_rec_curve_visualization.py b/examples/miscellaneous/plot_rec_curve_visualization.py index 4fcc9ef8856ff..04988aa3d50d7 100644 --- a/examples/miscellaneous/plot_rec_curve_visualization.py +++ b/examples/miscellaneous/plot_rec_curve_visualization.py @@ -3,8 +3,8 @@ Regression Error Characteristic curve display ============================================= -We illustrate how to display the regression error characteristic curve, and how to use it -to compare various regression models. +We illustrate how to display the regression error characteristic curve, and how to use +it to compare various regression models. """ # Authors: The scikit-learn developers @@ -37,21 +37,23 @@ lr_estimator.fit(X_train, y_train) # %% -# Display REC curve using the test set. See that our linear regressor's REC curve dominates -# the curve of the default constant predictor - the median. This is an indicator that our model is doing something -# "reasonable". +# Display REC curve using the test set. See that our linear regressor's REC curve +# dominates the curve of the default constant predictor - the median. This is an +# indicator that our model is doing something "reasonable". from sklearn.metrics import RecCurveDisplay RecCurveDisplay.from_estimator(lr_estimator, X_test, y_test, name="Linear regression") # %% -# Compare two REC curves of linear regression to KNN and histogram gradient boosting regressors. 
We can see a clear -# hierarchy here. KNN is strictly better than linear regression for any error tolerance, and HGBR is strictly better -# than both for any error tolerance. +# Compare two REC curves of linear regression to KNN and histogram gradient boosting +# regressors. We can see a clear hierarchy here. KNN is strictly better than linear +# regression for any error tolerance, and HGBR is strictly better than both for any +# error tolerance. import matplotlib.pyplot as plt -from sklearn.metrics import root_mean_squared_error, mean_absolute_error + from sklearn.ensemble import HistGradientBoostingRegressor +from sklearn.metrics import mean_absolute_error, root_mean_squared_error from sklearn.neighbors import KNeighborsRegressor hgbr_estimator = HistGradientBoostingRegressor() @@ -64,9 +66,18 @@ pred_hgbr = hgbr_estimator.predict(X_test) pred_knn = knn_estimator.predict(X_test) -lr_metrics = f"RMSE = {root_mean_squared_error(pred_lr, y_test):.2f}, MAE = {mean_absolute_error(pred_lr, y_test):.2f}" -hgbr_metrics = f"RMSE = {root_mean_squared_error(pred_hgbr, y_test):.2f}, MAE = {mean_absolute_error(pred_hgbr, y_test):.2f}" -knn_metrics = f"RMSE = {root_mean_squared_error(pred_knn, y_test):.2f}, MAE = {mean_absolute_error(pred_knn, y_test):.2f}" +lr_metrics = ( + f"RMSE = {root_mean_squared_error(pred_lr, y_test):.2f}, " + f"MAE = {mean_absolute_error(pred_lr, y_test):.2f}" +) +hgbr_metrics = ( + f"RMSE = {root_mean_squared_error(pred_hgbr, y_test):.2f}, " + f"MAE = {mean_absolute_error(pred_hgbr, y_test):.2f}" +) +knn_metrics = ( + f"RMSE = {root_mean_squared_error(pred_knn, y_test):.2f}, " + f"MAE = {mean_absolute_error(pred_knn, y_test):.2f}" +) fig, ax = plt.subplots() RecCurveDisplay.from_predictions( diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py index f0fbc5674835b..5a43582b18a52 100644 --- a/sklearn/metrics/_regression_characteristic.py +++ b/sklearn/metrics/_regression_characteristic.py @@ -16,7 
+16,7 @@ "y_pred": ["array-like", numbers.Real], # Allow scalar or array-like "loss": [StrOptions({"absolute", "squared"})], }, - prefer_skip_nested_validation=True, # Standard practice for functions doing further checks + prefer_skip_nested_validation=True, ) def rec_curve(y_true, y_pred, *, loss="absolute"): """Compute Regression Error Characteristic (REC) curve. @@ -26,11 +26,6 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): within that tolerance (accuracy) on the y-axis. It is the empirical Cumulative Distribution Function (CDF) of the error. - This implementation is designed to be compatible with the array_api - standard and scikit-learn's utilities. - - Read more in the :ref:`User Guide `. (Assuming this would be added) - Parameters ---------- y_true : array-like of shape (n_samples,) @@ -100,7 +95,7 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): >>> # import matplotlib.pyplot as plt >>> # plt.figure() >>> # plt.plot(deviations, accuracy, marker='.', label='Model 1 (Absolute Loss)') - >>> # plt.plot(dev_scalar, acc_scalar, marker='.', label='Constant Model (Absolute Loss)') + >>> # plt.plot(dev_scalar, acc_scalar, marker='.', label='Constant (Absolute Loss)') >>> # plt.xlabel("Error Tolerance (Deviation)") >>> # plt.ylabel("Accuracy (Fraction of samples within tolerance)") >>> # plt.title("Regression Error Characteristic (REC) Curve") @@ -114,7 +109,7 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): ) xp, _, device = get_namespace_and_device(y_true_array) - # Python native scalars (int, float) + # Handle Python native scalars (int, float) if isinstance(y_pred, numbers.Number): # numbers.Real covers int, float y_pred_scalar_val = float(y_pred) y_pred = xp.full( @@ -124,7 +119,7 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): device=device, ) - # array-like with a single prediction + # Handle array-like with a single prediction if y_pred.size == 1: y_pred = xp.squeeze(y_pred) y_pred = xp.tile(y_pred, y_true_array.shape) @@ -134,14 +129,7 
@@ def rec_curve(y_true, y_pred, *, loss="absolute"): ) check_consistent_length(y_true_array, y_pred_array) - # Validate loss parameter - if loss not in ("absolute", "squared"): - raise ValueError( - f"loss type '{loss}' not supported, choose 'absolute' or 'squared'." - ) - # Calculate deviations based on the chosen loss - # Since y_true_array and y_pred_array are finite, differences and errors will be finite. differences = y_true_array - y_pred_array if loss == "absolute": errors = xp.abs(differences) @@ -160,31 +148,23 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): # OPTIMIZED CDF CALCULATION: # Get unique sorted error values (deviations_calc) and their counts. # xp.unique_counts returns sorted unique values. - # Since errors are finite, deviations_calc will also be finite and non-empty if n_samples > 0. deviations_calc, counts = xp.unique_counts(errors) # Calculate cumulative accuracy cumulative_counts = xp.cumsum(counts) - # Ensure accuracy_values is float64 for consistency and precision. accuracy_values = xp.astype(cumulative_counts, xp.float64) / float(n_samples) # Prepare output deviations and accuracy # Prepend (0,0) if the smallest error (first element of deviations_calc) is > 0.0, # ensuring the curve starts from the origin of the plot unless # there are samples with exactly zero error. - # deviations_calc[0] is safe to access as n_samples > 0 implies deviations_calc is non-empty. if deviations_calc[0] > 0.0: - # Create zero point with the correct dtype and device - # deviations_calc.dtype could be float32 or float64 depending on input error calculation. zero_dev = xp.asarray([0.0], dtype=deviations_calc.dtype, device=device) - # accuracy_values is already float64. 
zero_acc = xp.asarray([0.0], dtype=accuracy_values.dtype, device=device) deviations_out = xp.concatenate((zero_dev, deviations_calc)) accuracy_out = xp.concatenate((zero_acc, accuracy_values)) else: - # Smallest error is 0.0 (or less, though errors should be non-negative and finite) - # The curve naturally starts at (0, accuracy_for_zero_error) deviations_out = deviations_calc accuracy_out = accuracy_values From 73458c57564b75c232a67b7b66776611ac7a23cc Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 15:25:25 +0300 Subject: [PATCH 10/26] Fixed more docs --- sklearn/metrics/_plot/rec_curve.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index a99f328ae566c..bca4f0ff08213 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -16,7 +16,6 @@ class RecCurveDisplay: or :func:`~sklearn.metrics.RecCurveDisplay.from_predictions` to create a visualizer. All parameters are stored as attributes. - Read more in the :ref:`User Guide `. (Assuming this would be added) Parameters ---------- @@ -52,8 +51,6 @@ class RecCurveDisplay: See Also -------- rec_curve : Compute Regression Error Characteristic (REC) curve. - RecCurveDisplay.from_estimator : Plot REC curve given an estimator and data. - RecCurveDisplay.from_predictions : Plot REC curve given true and predicted values. Examples -------- @@ -209,9 +206,6 @@ def from_estimator( Input values. y : array-like of shape (n_samples,) Target values. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. Currently not supported by the underlying `rec_curve` - function and will raise an error if provided. loss : {'absolute', 'squared'}, default='absolute' The loss function to use for calculating deviations. 
constant_predictor : {'mean', 'median', None}, default=None @@ -225,8 +219,7 @@ def from_estimator( Whether to compute and plot the REC curve for the constant predictor. clip_max_const_error : bool, default=True If `True`, the x-axis (error tolerance) will be cut off at the - maximum error achieved by the constant predictor. This is only - effective if a constant predictor is computed. + maximum error achieved by the constant predictor. name : str, default=None Name for the REC curve. If `None`, the estimator's class name will be used. ax : matplotlib axes, default=None From 675b9f082cb4e4c6cea2dce0b35e76b99caf942f Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 15:32:14 +0300 Subject: [PATCH 11/26] Doc fixes --- sklearn/metrics/_plot/rec_curve.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index bca4f0ff08213..926064c320bf6 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -57,19 +57,16 @@ class RecCurveDisplay: >>> import matplotlib.pyplot as plt >>> import numpy as np >>> from sklearn.linear_model import LinearRegression - >>> # Assuming rec_curve function is defined and available (mocked above for example) >>> X = np.array([[1], [2], [3], [4], [5]]) >>> y = np.array([1, 2.5, 3, 4.5, 5]) >>> estimator = LinearRegression().fit(X, y) >>> display = RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute') - >>> # display.plot() # To show plot in interactive environment - >>> # plt.show() + <...> >>> y_pred = estimator.predict(X) >>> display_pred = RecCurveDisplay.from_predictions( ... y, y_pred, loss='squared', name="My Model", plot_const_predictor=False ... 
) - >>> # display_pred.plot() - >>> # plt.show() + <...> """ def __init__( From 8e2f942460104782b8a61f7adbf0a50432f1e191 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 15:33:15 +0300 Subject: [PATCH 12/26] More cosmetics --- sklearn/metrics/_plot/rec_curve.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 926064c320bf6..f5a9bc02c58a5 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -137,7 +137,7 @@ def plot( max_const_error = max(self.constant_predictor_deviations) elif clip_max_const_error: raise ValueError( - "clip_max_const_error is True, but no constant deviations were provided." + "clip_max_const_error is True, but no constant deviations were given." ) if clip_max_const_error: @@ -325,7 +325,6 @@ def from_predictions( elif actual_constant_predictor_type == "median": constant_value = np.median(y_true_np) cp_name_val = "Median Predictor" - # No else needed here as validate_params covers constant_predictor values cp_devs, cp_accs = rec_curve(y_true, constant_value, loss=loss) cp_devs_np = np.asarray(cp_devs) From c035196bb313a11aa43e43d6e5af9aae6e74c83d Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 15:34:10 +0300 Subject: [PATCH 13/26] Removed unused argument from RecCurveDisplay ctor --- sklearn/metrics/_plot/rec_curve.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index f5a9bc02c58a5..6d31c686f4dc0 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -76,7 +76,6 @@ def __init__( accuracy, estimator_name=None, loss=None, - max_const_error=None, constant_predictor_deviations=None, constant_predictor_accuracy=None, constant_predictor_name=None, From 4cfbcb6b3c284c2d46b979ead2f2c1e59846ec21 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 
21:45:48 +0300 Subject: [PATCH 14/26] Added changelog entries --- doc/whats_new/_contributors.rst | 2 ++ .../upcoming_changes/sklearn.metrics/31380.feature.rst | 8 ++++++++ 2 files changed, 10 insertions(+) create mode 100644 doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst diff --git a/doc/whats_new/_contributors.rst b/doc/whats_new/_contributors.rst index c74a2964e57bc..d56bdb9bc45ee 100644 --- a/doc/whats_new/_contributors.rst +++ b/doc/whats_new/_contributors.rst @@ -179,3 +179,5 @@ .. _Guillaume Lemaitre: https://github.com/glemaitre .. _Tim Head: https://betatim.github.io/ + +.. _Alex Shtoff: https://alexshtf.github.io/ \ No newline at end of file diff --git a/doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst b/doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst new file mode 100644 index 0000000000000..7aef8b28925f8 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.metrics/31380.feature.rst @@ -0,0 +1,8 @@ +The new function :func:`metrics.rec_curve` computes the Regression Error Characteristic +(REC) curve of error tolerances vs the fraction of samples below the tolerance (the CDF +of the regression errors). Suggested in the ICML'03 paper "Regression error +characteristic curves" by Bi and Bennett as the regression-variant of the ROC curve. + +The new class :class:`metrics.RecCurveDisplay` visualizes the REC curves.
+ +By :user:`Alex Shtoff <alexshtf>` \ No newline at end of file From ab6db062dfe187dcf3f3e154935fd2cb6e657df0 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 21:51:25 +0300 Subject: [PATCH 15/26] Improved REC plot docs --- sklearn/metrics/_plot/rec_curve.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 6d31c686f4dc0..c76e86ffeb1e9 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -16,7 +16,6 @@ class RecCurveDisplay: or :func:`~sklearn.metrics.RecCurveDisplay.from_predictions` to create a visualizer. All parameters are stored as attributes. - Parameters ---------- deviations : ndarray @@ -38,8 +37,7 @@ class RecCurveDisplay: Attributes ---------- - line_ : matplotlib Artist - REC curve. + line_ : matplotlib Artist of REC curve. ax_ : matplotlib Axes Axes with REC curve. figure_ : matplotlib Figure @@ -48,6 +46,14 @@ class RecCurveDisplay: Constant predictor REC curve. Only defined if a constant predictor was plotted. + Remarks + ------- + This class uses the heuristic suggested in the REC paper to clip the x axis + at the maximum error achieved by some reasonable constant predictor, such as + the mean or the median. This is a heuristic to make plots more interpretable, + since we focus only on the portion of "reasonable" errors. If a predictor is worse + than a constant - it's practically worthless. + See Also -------- rec_curve : Compute Regression Error Characteristic (REC) curve. @@ -107,6 +113,11 @@ def plot( name : str, default=None Name of REC curve for labeling. If `None`, use the name stored in `estimator_name`. + plot_const_predictor : bool, default=True + If enabled, will plot the REC curve of the constant predictor. + clip_max_const_error: bool, default=True + If enabled, will clip the horizontal axis at the maximum error achieved + by the constant predictor.
**kwargs : dict Keyword arguments to be passed to `matplotlib.pyplot.plot` for the main REC curve. From 7625c2c39435c82a64ae327c16afcdfd79d30473 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 21:57:42 +0300 Subject: [PATCH 16/26] Docstring fixes --- sklearn/metrics/_plot/rec_curve.py | 5 +++++ sklearn/metrics/_regression_characteristic.py | 16 +++------------- 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index c76e86ffeb1e9..e740477c6ee03 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -63,11 +63,16 @@ class RecCurveDisplay: >>> import matplotlib.pyplot as plt >>> import numpy as np >>> from sklearn.linear_model import LinearRegression + >>> from sklearn.metrics import RecCurveDisplay + + >>> # from_estimator example >>> X = np.array([[1], [2], [3], [4], [5]]) >>> y = np.array([1, 2.5, 3, 4.5, 5]) >>> estimator = LinearRegression().fit(X, y) >>> display = RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute') <...> + + >>> # from_predictions example >>> y_pred = estimator.predict(X) >>> display_pred = RecCurveDisplay.from_predictions( ... 
y, y_pred, loss='squared', name="My Model", plot_const_predictor=False diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py index 5a43582b18a52..07d61eb8a54a5 100644 --- a/sklearn/metrics/_regression_characteristic.py +++ b/sklearn/metrics/_regression_characteristic.py @@ -67,7 +67,9 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): Examples -------- >>> import numpy as np - >>> from sklearn.metrics import rec_curve # Assuming function is in sklearn.metrics + >>> from sklearn.metrics import rec_curve + + >>> # example with absolute loss >>> y_true = np.array([1, 2, 3, 4, 5, 6]) >>> y_pred_model1 = np.array([1.1, 2.2, 2.8, 4.3, 4.8, 6.5]) >>> deviations, accuracy = rec_curve(y_true, y_pred_model1, loss='absolute') @@ -90,18 +92,6 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): array([0. , 0.01 , 0.04 , 0.09 , 0.25 ]) >>> acc_sq array([0. , 0.16666667, 0.66666667, 0.83333333, 1. ]) - - >>> # For plotting with matplotlib: - >>> # import matplotlib.pyplot as plt - >>> # plt.figure() - >>> # plt.plot(deviations, accuracy, marker='.', label='Model 1 (Absolute Loss)') - >>> # plt.plot(dev_scalar, acc_scalar, marker='.', label='Constant (Absolute Loss)') - >>> # plt.xlabel("Error Tolerance (Deviation)") - >>> # plt.ylabel("Accuracy (Fraction of samples within tolerance)") - >>> # plt.title("Regression Error Characteristic (REC) Curve") - >>> # plt.legend() - >>> # plt.grid(True) - >>> # plt.show() """ # Validate y_true and get the array namespace (xp) y_true_array = check_array( From 34b52365a9102edd4e1f02cae4a95ae832d3e067 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 22:33:34 +0300 Subject: [PATCH 17/26] Bugfixes and doctest fixes --- sklearn/metrics/_plot/rec_curve.py | 7 ++++++- sklearn/metrics/_regression_characteristic.py | 9 +++++++-- sklearn/metrics/tests/test_regression.py | 8 ++++---- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git 
a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index e740477c6ee03..3520f2109ae82 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -14,7 +14,10 @@ class RecCurveDisplay: It is recommended to use :func:`~sklearn.metrics.RecCurveDisplay.from_estimator` or :func:`~sklearn.metrics.RecCurveDisplay.from_predictions` to create - a visualizer. All parameters are stored as attributes. + a visualizer. + + For general information regarding `scikit-learn` visualization tools, see + the :ref:`Visualization Guide <visualizations>`. Parameters ---------- @@ -71,6 +74,7 @@ class RecCurveDisplay: >>> estimator = LinearRegression().fit(X, y) >>> display = RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute') <...> + >>> plt.show() >>> # from_predictions example >>> y_pred = estimator.predict(X) @@ -78,6 +82,7 @@ class RecCurveDisplay: ... y, y_pred, loss='squared', name="My Model", plot_const_predictor=False ... ) <...> + >>> plt.show() """ def __init__( diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py index 07d61eb8a54a5..04ed4c6e86433 100644 --- a/sklearn/metrics/_regression_characteristic.py +++ b/sklearn/metrics/_regression_characteristic.py @@ -6,7 +6,7 @@ import numbers # For type checking Python scalars from ..utils import check_array, check_consistent_length -from ..utils._array_api import get_namespace_and_device # For array_api support +from ..utils._array_api import get_namespace_and_device, _find_matching_floating_dtype from ..utils._param_validation import StrOptions, validate_params @@ -105,7 +105,7 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): y_pred = xp.full( y_true_array.shape, fill_value=y_pred_scalar_val, - dtype=y_true_array.dtype, # Match y_true's dtype for consistency + dtype=_find_matching_floating_dtype(y_true_array, xp=xp), device=device, ) @@ -119,6 +119,11 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): )
check_consistent_length(y_true_array, y_pred_array) + # cast to common floating point dtype. + common_dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp) + y_true_array = xp.astype(y_true_array, common_dtype, copy=False) + y_pred_array = xp.astype(y_pred_array, common_dtype, copy=False) + # Calculate deviations based on the chosen loss differences = y_true_array - y_pred_array if loss == "absolute": diff --git a/sklearn/metrics/tests/test_regression.py b/sklearn/metrics/tests/test_regression.py index efe10c707c142..6f8755403e7e1 100644 --- a/sklearn/metrics/tests/test_regression.py +++ b/sklearn/metrics/tests/test_regression.py @@ -635,14 +635,14 @@ def test_pinball_loss_relation_with_mae(global_random_seed): ) -@pytest.mark.parametrize("constant_one_pred", [1.0, np.asarray(1.0), np.asarray([1.0])]) +@pytest.mark.parametrize("constant_one_pred", [3.5, np.asarray(3.5), np.asarray([3.5])]) def test_rec_curve_const_pred(constant_one_pred): - y_true = np.array([-1, 1, 2, -2, 0]) + y_true = np.array([5, 6, 4, 3, 1, 2]) deviations, accuracy = rec_curve(y_true, constant_one_pred) - assert_allclose(deviations, np.asarray([0.0, 1.0, 2.0, 3.0])) - assert_allclose(accuracy, np.asarray([0.2, 0.6, 0.8, 1.0])) + assert_allclose(deviations, np.asarray([0.0, 0.5, 1.5, 2.5])) + assert_allclose(accuracy, np.asarray([0.0, 1 / 3.0, 2 / 3.0, 1.0])) def test_rec_curve_array_pred(): From 604531a55fd2148c635b05934c536d1902ef80be Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 22:35:20 +0300 Subject: [PATCH 18/26] Organize imports --- sklearn/metrics/_regression_characteristic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py index 04ed4c6e86433..390e92c77f16b 100644 --- a/sklearn/metrics/_regression_characteristic.py +++ b/sklearn/metrics/_regression_characteristic.py @@ -6,7 +6,7 @@ import numbers # For type checking Python scalars from ..utils 
import check_array, check_consistent_length -from ..utils._array_api import get_namespace_and_device, _find_matching_floating_dtype +from ..utils._array_api import _find_matching_floating_dtype, get_namespace_and_device from ..utils._param_validation import StrOptions, validate_params From a36bf45a4da6cfb61fffab8f2a0b3ad08c4dc9db Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 22:53:10 +0300 Subject: [PATCH 19/26] Fix doctest --- sklearn/metrics/_plot/rec_curve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 3520f2109ae82..0330dcd043fa3 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -72,13 +72,13 @@ class RecCurveDisplay: >>> X = np.array([[1], [2], [3], [4], [5]]) >>> y = np.array([1, 2.5, 3, 4.5, 5]) >>> estimator = LinearRegression().fit(X, y) - >>> display = RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute') + >>> RecCurveDisplay.from_estimator(estimator, X, y, loss='absolute') <...> >>> plt.show() >>> # from_predictions example >>> y_pred = estimator.predict(X) - >>> display_pred = RecCurveDisplay.from_predictions( + >>> RecCurveDisplay.from_predictions( ... y, y_pred, loss='squared', name="My Model", plot_const_predictor=False ... ) <...> From fb444188fc23c69133f4084ec5bb1c5d2deb39a8 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 23:13:24 +0300 Subject: [PATCH 20/26] Renamed Remarks to Notes section to conform to sklearn standards --- sklearn/metrics/_plot/rec_curve.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 0330dcd043fa3..d8ae1121b4fff 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -49,8 +49,8 @@ class RecCurveDisplay: Constant predictor REC curve. Only defined if a constant predictor was plotted. 
- Remarks - ------- + Notes + ----- This class uses the heuristic suggested in the REC paper to clip the x axis at the maximum error achieved by some reasonable constant predictor, such as the mean or the median. This is a heuristic to make plots more interpretable, From 708fdf00e528120cd6aecc378fbb18352064ed44 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 23:15:54 +0300 Subject: [PATCH 21/26] Fixed typo in docstrings --- sklearn/metrics/_plot/rec_curve.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index d8ae1121b4fff..7259747591b47 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -125,7 +125,7 @@ def plot( `estimator_name`. plot_const_predictor : bool, default=True If enabled, will plot the REC curve of the constant predictor. - clip_max_const_error: bool, default=True + clip_max_const_error : bool, default=True If enabled, will clip the horizontal axis at the maximum error achieved by the constant predictor. **kwargs : dict From 1dbf56f75f8c9e14a06a8ec171ce4eebb6256817 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Sun, 18 May 2025 23:35:27 +0300 Subject: [PATCH 22/26] Fixed docstring section order --- sklearn/metrics/_plot/rec_curve.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 7259747591b47..705b439dc8a93 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -49,6 +49,10 @@ class RecCurveDisplay: Constant predictor REC curve. Only defined if a constant predictor was plotted. + See Also + -------- + rec_curve : Compute Regression Error Characteristic (REC) curve. + Notes ----- This class uses the heuristic suggested in the REC paper to clip the x axis @@ -57,10 +61,6 @@ class RecCurveDisplay: since we focus only on the portion of "reasonable" errors.
If a predictor is worse than a constant - it's practically worthless. - See Also - -------- - rec_curve : Compute Regression Error Characteristic (REC) curve. - Examples -------- >>> import matplotlib.pyplot as plt From 85515f8d941aa4954dbe7be82459bd2f76c336cc Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Mon, 19 May 2025 09:08:12 +0300 Subject: [PATCH 23/26] Added REC curve display tests --- .../_plot/tests/test_rec_curve_display.py | 154 ++++++++++++++++++ 1 file changed, 154 insertions(+) create mode 100644 sklearn/metrics/_plot/tests/test_rec_curve_display.py diff --git a/sklearn/metrics/_plot/tests/test_rec_curve_display.py b/sklearn/metrics/_plot/tests/test_rec_curve_display.py new file mode 100644 index 0000000000000..4413f26b287f0 --- /dev/null +++ b/sklearn/metrics/_plot/tests/test_rec_curve_display.py @@ -0,0 +1,154 @@ +import numpy as np +import pytest +from numpy.testing import assert_allclose + +from sklearn import clone +from sklearn.compose import make_column_transformer +from sklearn.datasets import make_regression +from sklearn.exceptions import NotFittedError +from sklearn.linear_model import LinearRegression +from sklearn.metrics import RecCurveDisplay, rec_curve +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler + + +@pytest.fixture(scope="module") +def regression_data(): + X, y = make_regression( + n_samples=200, + n_features=20, + n_informative=5, + noise=1, + random_state=42, + ) + return X, y + + +@pytest.mark.parametrize("plot_const_predictor", [True, False]) +@pytest.mark.parametrize("clip_max_const_error", [True, False]) +@pytest.mark.parametrize("constant_predictor", [None, "mean", "median"]) +@pytest.mark.parametrize("loss", [None, "squared", "absolute"]) +@pytest.mark.parametrize( + "constructor_name, default_name", + [ + ("from_estimator", "LinearRegression"), + ("from_predictions", "Model"), + ], +) +def test_roc_curve_display_plotting( + pyplot, + regression_data, + 
plot_const_predictor, + clip_max_const_error, + constant_predictor, + loss, + constructor_name, + default_name, +): + X, y = regression_data + + lr = LinearRegression() + lr.fit(X, y) + y_score = lr.predict(X) + + ctor_kwargs = dict( + plot_const_predictor=plot_const_predictor, + clip_max_const_error=clip_max_const_error, + ) + if loss is not None: + ctor_kwargs |= dict(loss=loss) + if constant_predictor is not None: + ctor_kwargs |= dict(constant_predictor=constant_predictor) + + if constructor_name == "from_estimator": + display = RecCurveDisplay.from_estimator(lr, X, y, **ctor_kwargs) + else: + display = RecCurveDisplay.from_predictions(y, y_score, **ctor_kwargs) + + if loss is None: + deviations, accuracy = rec_curve(y, y_score) + else: + deviations, accuracy = rec_curve(y, y_score, loss=loss) + + assert_allclose(display.deviations, deviations) + assert_allclose(display.accuracy, accuracy) + + if plot_const_predictor or clip_max_const_error: + const_predictor_labels = { + "mean": "Mean Predictor", + "median": "Median Predictor", + } + const_predictor_key = constant_predictor or ( + "mean" if loss == "squared" else "median" + ) + assert ( + display.constant_predictor_name + == const_predictor_labels[const_predictor_key] + ) + + const_value = np.mean(y) if const_predictor_key == "mean" else np.median(y) + if loss is None: + exp_const_deviations, exp_const_accuracy = rec_curve(y, const_value) + else: + exp_const_deviations, exp_const_accuracy = rec_curve( + y, const_value, loss=loss + ) + assert_allclose(display.constant_predictor_deviations, exp_const_deviations) + assert_allclose(display.constant_predictor_accuracy, exp_const_accuracy) + + assert display.estimator_name == default_name + assert display.loss == (loss or "absolute") + + import matplotlib as mpl + + assert isinstance(display.line_, mpl.lines.Line2D) + assert isinstance(display.ax_, mpl.axes.Axes) + assert isinstance(display.figure_, mpl.figure.Figure) + + if plot_const_predictor: + assert 
isinstance(display.constant_predictor_line_, mpl.lines.Line2D) + + expected_label = f"{default_name}" + assert display.line_.get_label() == expected_label + + expected_ylabel = "Accuracy (Fraction of samples)" + expected_xlabel = f"Error Tolerance (Deviation - {loss or 'absolute'} loss)" + + assert display.ax_.get_ylabel() == expected_ylabel + assert display.ax_.get_xlabel() == expected_xlabel + + +@pytest.mark.parametrize( + "clf", + [ + LinearRegression(), + make_pipeline(StandardScaler(), LinearRegression()), + make_pipeline( + make_column_transformer((StandardScaler(), [0, 1])), LinearRegression() + ), + ], +) +@pytest.mark.parametrize("constructor_name", ["from_estimator", "from_predictions"]) +def test_rec_curve_display_complex_pipeline( + pyplot, regression_data, clf, constructor_name +): + """Check the behaviour with complex pipeline.""" + X, y = regression_data + + clf = clone(clf) + + if constructor_name == "from_estimator": + with pytest.raises(NotFittedError): + RecCurveDisplay.from_estimator(clf, X, y) + + clf.fit(X, y) + + if constructor_name == "from_estimator": + display = RecCurveDisplay.from_estimator(clf, X, y) + name = clf.__class__.__name__ + else: + display = RecCurveDisplay.from_predictions(y, y) + name = "Model" + + assert name in display.line_.get_label() + assert display.estimator_name == name From 9173f46675c177ca915fe5ae316d5bed6aaad644 Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Mon, 19 May 2025 10:05:16 +0300 Subject: [PATCH 24/26] Refactor code to reduce duplication in plotting --- sklearn/metrics/_plot/rec_curve.py | 20 ++++++-------------- sklearn/utils/_plotting.py | 16 ++++++++++------ 2 files changed, 16 insertions(+), 20 deletions(-) diff --git a/sklearn/metrics/_plot/rec_curve.py b/sklearn/metrics/_plot/rec_curve.py index 705b439dc8a93..cbdbce557d1bc 100644 --- a/sklearn/metrics/_plot/rec_curve.py +++ b/sklearn/metrics/_plot/rec_curve.py @@ -5,11 +5,11 @@ from ...base import is_regressor # To check if estimator is a 
regressor from ...metrics._regression_characteristic import rec_curve -from ...utils._optional_dependencies import check_matplotlib_support +from ...utils._plotting import _CurveDisplayMixin, _validate_style_kwargs from ...utils.validation import check_is_fitted -class RecCurveDisplay: +class RecCurveDisplay(_CurveDisplayMixin): """Regression Error Characteristic (REC) Curve visualization. It is recommended to use :func:`~sklearn.metrics.RecCurveDisplay.from_estimator` @@ -137,21 +137,12 @@ def plot( display : :class:`~sklearn.metrics.RecCurveDisplay` Object that stores computed values. """ + self.ax_, self.figure_, name = self._validate_plot_params(ax=ax, name=name) - check_matplotlib_support(f"{self.__class__.__name__}.plot") - import matplotlib.pyplot as plt - - if ax is None: - self.figure_, self.ax_ = plt.subplots() - else: - self.ax_ = ax - self.figure_ = self.ax_.figure - - plot_name = name if name is not None else self.estimator_name line_kwargs = {} if "label" not in kwargs: # Allow user to override label - line_kwargs["label"] = plot_name - line_kwargs.update(kwargs) + line_kwargs["label"] = name + line_kwargs = _validate_style_kwargs(line_kwargs, kwargs) if self.constant_predictor_deviations is not None: max_const_error = max(self.constant_predictor_deviations) @@ -183,6 +174,7 @@ def plot( ) # Default style for constant predictor, can be overridden if needed cp_kwargs = {"label": cp_name, "linestyle": "--"} + cp_kwargs = _validate_style_kwargs(cp_kwargs, kwargs) self.constant_predictor_line_, *_ = self.ax_.plot( self.constant_predictor_deviations, self.constant_predictor_accuracy, diff --git a/sklearn/utils/_plotting.py b/sklearn/utils/_plotting.py index 946c95186374b..9a01e080a6676 100644 --- a/sklearn/utils/_plotting.py +++ b/sklearn/utils/_plotting.py @@ -10,12 +10,8 @@ from .validation import _check_pos_label_consistency -class _BinaryClassifierCurveDisplayMixin: - """Mixin class to be used in Displays requiring a binary classifier. 
- - The aim of this class is to centralize some validations regarding the estimator and - the target and gather the response of the estimator. - """ +class _CurveDisplayMixin: + """Mixin class to be used in Displays plotting a tradeoff curve.""" def _validate_plot_params(self, *, ax=None, name=None): check_matplotlib_support(f"{self.__class__.__name__}.plot") @@ -27,6 +23,14 @@ def _validate_plot_params(self, *, ax=None, name=None): name = self.estimator_name if name is None else name return ax, ax.figure, name + +class _BinaryClassifierCurveDisplayMixin(_CurveDisplayMixin): + """Mixin class to be used in Displays requiring a binary classifier. + + The aim of this class is to centralize some validations regarding the estimator and + the target and gather the response of the estimator. + """ + @classmethod def _validate_and_get_response_values( cls, estimator, X, y, *, response_method="auto", pos_label=None, name=None From 83f3dc925a12799ff2171c90e2c8cc0f01133dcb Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Mon, 19 May 2025 15:13:09 +0300 Subject: [PATCH 25/26] Ensured deviations are sorted, since xp.unique_counts does not ensure that --- sklearn/metrics/_regression_characteristic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sklearn/metrics/_regression_characteristic.py b/sklearn/metrics/_regression_characteristic.py index 390e92c77f16b..1e64ab71689fd 100644 --- a/sklearn/metrics/_regression_characteristic.py +++ b/sklearn/metrics/_regression_characteristic.py @@ -140,16 +140,16 @@ def rec_curve(y_true, y_pred, *, loss="absolute"): ) return empty_float_array, empty_float_array - # OPTIMIZED CDF CALCULATION: - # Get unique sorted error values (deviations_calc) and their counts. - # xp.unique_counts returns sorted unique values. 
+ # compute deviations and counts in sorted order deviations_calc, counts = xp.unique_counts(errors) + sort_order = xp.argsort(deviations_calc) + deviations_calc = deviations_calc[sort_order] + counts = counts[sort_order] # Calculate cumulative accuracy cumulative_counts = xp.cumsum(counts) accuracy_values = xp.astype(cumulative_counts, xp.float64) / float(n_samples) - # Prepare output deviations and accuracy # Prepend (0,0) if the smallest error (first element of deviations_calc) is > 0.0, # ensuring the curve starts from the origin of the plot unless # there are samples with exactly zero error. From bd3649b07a62531b59a782ffa50bdda3623504fd Mon Sep 17 00:00:00 2001 From: Alexander Shtoff Date: Mon, 19 May 2025 17:55:23 +0300 Subject: [PATCH 26/26] retrigger build