Skip to content

Commit e92dd40

Browse files
virchanOmarManzooradrinjalali
authored
ENH add support for array API to various metric (scikit-learn#29709)
Co-authored-by: Omar Salman <omar.salman2007@gmail.com> Co-authored-by: Adrin Jalali <adrin.jalali@gmail.com>
1 parent fde6f2d commit e92dd40

File tree

6 files changed

+136
-39
lines changed

6 files changed

+136
-39
lines changed

doc/modules/array_api.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ Metrics
121121
- :func:`sklearn.metrics.mean_gamma_deviance`
122122
- :func:`sklearn.metrics.mean_poisson_deviance` (requires `enabling array API support for SciPy <https://docs.scipy.org/doc/scipy/dev/api-dev/array_api.html#using-array-api-standard-support>`_)
123123
- :func:`sklearn.metrics.mean_squared_error`
124+
- :func:`sklearn.metrics.mean_squared_log_error`
124125
- :func:`sklearn.metrics.mean_tweedie_deviance`
125126
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel`
126127
- :func:`sklearn.metrics.pairwise.chi2_kernel`
@@ -134,6 +135,8 @@ Metrics
134135
- :func:`sklearn.metrics.pairwise.rbf_kernel` (see :ref:`device_support_for_float64`)
135136
- :func:`sklearn.metrics.pairwise.sigmoid_kernel`
136137
- :func:`sklearn.metrics.r2_score`
138+
- :func:`sklearn.metrics.root_mean_squared_error`
139+
- :func:`sklearn.metrics.root_mean_squared_log_error`
137140
- :func:`sklearn.metrics.zero_one_loss`
138141

139142
Tools

doc/whats_new/v1.6.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,10 @@ See :ref:`array_api` for more details.
5555
- :func:`sklearn.metrics.mean_gamma_deviance` :pr:`29239` by :user:`Emily Chen <EmilyXinyi>`;
5656
- :func:`sklearn.metrics.mean_poisson_deviance` :pr:`29227` by :user:`Emily Chen <EmilyXinyi>`;
5757
- :func:`sklearn.metrics.mean_squared_error` :pr:`29142` by :user:`Yaroslav Korobko <Tialo>`;
58+
- :func:`sklearn.metrics.mean_squared_log_error` :pr:`29709` by :user:`Virgil Chan <virchan>`;
5859
- :func:`sklearn.metrics.mean_tweedie_deviance` :pr:`28106` by :user:`Thomas Li <lithomas1>`;
60+
- :func:`sklearn.metrics.root_mean_squared_error` :pr:`29709` by :user:`Virgil Chan <virchan>`;
61+
- :func:`sklearn.metrics.root_mean_squared_log_error` :pr:`29709` by :user:`Virgil Chan <virchan>`;
5962
- :func:`sklearn.metrics.pairwise.additive_chi2_kernel` :pr:`29144` by :user:`Yaroslav Korobko <Tialo>`;
6063
- :func:`sklearn.metrics.pairwise.chi2_kernel` :pr:`29267` by :user:`Yaroslav Korobko <Tialo>`;
6164
- :func:`sklearn.metrics.pairwise.cosine_distances` :pr:`29265` by :user:`Emily Chen <EmilyXinyi>`;
@@ -313,6 +316,19 @@ Changelog
313316
is renamed into `ensure_all_finite`. `force_all_finite` will be removed in 1.8.
314317
:pr:`29404` by :user:`Jérémie du Boisberranger <jeremiedb>`.
315318

319+
- |Fix| the functions :func:`metrics.mean_squared_log_error` and
320+
:func:`metrics.root_mean_squared_log_error` now check whether
321+
the inputs are within the correct domain for the function
322+
:math:`y=\log(1+x)`, rather than :math:`y=\log(x)`.
323+
:pr:`29709` by :user:`Virgil Chan <virchan>`.
324+
325+
- |Fix| the functions :func:`metrics.mean_absolute_error`,
326+
:func:`metrics.mean_absolute_percentage_error`, :func:`metrics.mean_squared_error`
327+
and :func:`metrics.root_mean_squared_error` now explicitly check whether a scalar
328+
will be returned when `multioutput=uniform_average`.
329+
:pr:`29709` by :user:`Virgil Chan <virchan>`.
330+
331+
316332
:mod:`sklearn.model_selection`
317333
..............................
318334

sklearn/metrics/_regression.py

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -217,13 +217,12 @@ def mean_absolute_error(
217217
multioutput = None
218218

219219
# Average across the outputs (if needed).
220+
# The second call to `_average` should always return
221+
# a scalar array that we convert to a Python float to
222+
# consistently return the same eager evaluated value.
223+
# Therefore, `axis=None`.
220224
mean_absolute_error = _average(output_errors, weights=multioutput)
221225

222-
# Since `y_pred.ndim <= 2` and `y_true.ndim <= 2`, the second call to _average
223-
# should always return a scalar array that we convert to a Python float to
224-
# consistently return the same eager evaluated value, irrespective of the
225-
# Array API implementation.
226-
assert mean_absolute_error.shape == ()
227226
return float(mean_absolute_error)
228227

229228

@@ -416,8 +415,13 @@ def mean_absolute_percentage_error(
416415
# pass None as weights to _average: uniform mean
417416
multioutput = None
418417

418+
# Average across the outputs (if needed).
419+
# The second call to `_average` should always return
420+
# a scalar array that we convert to a Python float to
421+
# consistently return the same eager evaluated value.
422+
# Therefore, `axis=None`.
419423
mean_absolute_percentage_error = _average(output_errors, weights=multioutput)
420-
assert mean_absolute_percentage_error.shape == ()
424+
421425
return float(mean_absolute_percentage_error)
422426

423427

@@ -524,12 +528,16 @@ def mean_squared_error(
524528
if multioutput == "raw_values":
525529
return output_errors
526530
elif multioutput == "uniform_average":
527-
# pass None as weights to np.average: uniform mean
531+
# pass None as weights to _average: uniform mean
528532
multioutput = None
529533

530-
# See comment in mean_absolute_error
534+
# Average across the outputs (if needed).
535+
# The second call to `_average` should always return
536+
# a scalar array that we convert to a Python float to
537+
# consistently return the same eager evaluated value.
538+
# Therefore, `axis=None`.
531539
mean_squared_error = _average(output_errors, weights=multioutput)
532-
assert mean_squared_error.shape == ()
540+
533541
return float(mean_squared_error)
534542

535543

@@ -585,13 +593,16 @@ def root_mean_squared_error(
585593
>>> y_true = [3, -0.5, 2, 7]
586594
>>> y_pred = [2.5, 0.0, 2, 8]
587595
>>> root_mean_squared_error(y_true, y_pred)
588-
np.float64(0.612...)
596+
0.612...
589597
>>> y_true = [[0.5, 1],[-1, 1],[7, -6]]
590598
>>> y_pred = [[0, 2],[-1, 2],[8, -5]]
591599
>>> root_mean_squared_error(y_true, y_pred)
592-
np.float64(0.822...)
600+
0.822...
593601
"""
594-
output_errors = np.sqrt(
602+
603+
xp, _ = get_namespace(y_true, y_pred, sample_weight, multioutput)
604+
605+
output_errors = xp.sqrt(
595606
mean_squared_error(
596607
y_true, y_pred, sample_weight=sample_weight, multioutput="raw_values"
597608
)
@@ -601,10 +612,17 @@ def root_mean_squared_error(
601612
if multioutput == "raw_values":
602613
return output_errors
603614
elif multioutput == "uniform_average":
604-
# pass None as weights to np.average: uniform mean
615+
# pass None as weights to _average: uniform mean
605616
multioutput = None
606617

607-
return np.average(output_errors, weights=multioutput)
618+
# Average across the outputs (if needed).
619+
# The second call to `_average` should always return
620+
# a scalar array that we convert to a Python float to
621+
# consistently return the same eager evaluated value.
622+
# Therefore, `axis=None`.
623+
root_mean_squared_error = _average(output_errors, weights=multioutput)
624+
625+
return float(root_mean_squared_error)
608626

609627

610628
@validate_params(
@@ -700,20 +718,22 @@ def mean_squared_log_error(
700718
y_true, y_pred, sample_weight=sample_weight, multioutput=multioutput
701719
)
702720

703-
y_type, y_true, y_pred, multioutput = _check_reg_targets(
704-
y_true, y_pred, multioutput
721+
xp, _ = get_namespace(y_true, y_pred)
722+
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
723+
724+
_, y_true, y_pred, _ = _check_reg_targets(
725+
y_true, y_pred, multioutput, dtype=dtype, xp=xp
705726
)
706-
check_consistent_length(y_true, y_pred, sample_weight)
707727

708-
if (y_true < 0).any() or (y_pred < 0).any():
728+
if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
709729
raise ValueError(
710730
"Mean Squared Logarithmic Error cannot be used when "
711-
"targets contain negative values."
731+
"targets contain values less than or equal to -1."
712732
)
713733

714734
return mean_squared_error(
715-
np.log1p(y_true),
716-
np.log1p(y_pred),
735+
xp.log1p(y_true),
736+
xp.log1p(y_pred),
717737
sample_weight=sample_weight,
718738
multioutput=multioutput,
719739
)
@@ -773,20 +793,24 @@ def root_mean_squared_log_error(
773793
>>> y_true = [3, 5, 2.5, 7]
774794
>>> y_pred = [2.5, 5, 4, 8]
775795
>>> root_mean_squared_log_error(y_true, y_pred)
776-
np.float64(0.199...)
796+
0.199...
777797
"""
778-
_, y_true, y_pred, multioutput = _check_reg_targets(y_true, y_pred, multioutput)
779-
check_consistent_length(y_true, y_pred, sample_weight)
798+
xp, _ = get_namespace(y_true, y_pred)
799+
dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
800+
801+
_, y_true, y_pred, multioutput = _check_reg_targets(
802+
y_true, y_pred, multioutput, dtype=dtype, xp=xp
803+
)
780804

781-
if (y_true < 0).any() or (y_pred < 0).any():
805+
if xp.any(y_true <= -1) or xp.any(y_pred <= -1):
782806
raise ValueError(
783807
"Root Mean Squared Logarithmic Error cannot be used when "
784-
"targets contain negative values."
808+
"targets contain values less than or equal to -1."
785809
)
786810

787811
return root_mean_squared_error(
788-
np.log1p(y_true),
789-
np.log1p(y_pred),
812+
xp.log1p(y_true),
813+
xp.log1p(y_pred),
790814
sample_weight=sample_weight,
791815
multioutput=multioutput,
792816
)

sklearn/metrics/tests/test_common.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
mean_pinball_loss,
3838
mean_poisson_deviance,
3939
mean_squared_error,
40+
mean_squared_log_error,
4041
mean_tweedie_deviance,
4142
median_absolute_error,
4243
multilabel_confusion_matrix,
@@ -47,6 +48,8 @@
4748
recall_score,
4849
roc_auc_score,
4950
roc_curve,
51+
root_mean_squared_error,
52+
root_mean_squared_log_error,
5053
top_k_accuracy_score,
5154
zero_one_loss,
5255
)
@@ -120,11 +123,14 @@
120123
"max_error": max_error,
121124
"mean_absolute_error": mean_absolute_error,
122125
"mean_squared_error": mean_squared_error,
126+
"mean_squared_log_error": mean_squared_log_error,
123127
"mean_pinball_loss": mean_pinball_loss,
124128
"median_absolute_error": median_absolute_error,
125129
"mean_absolute_percentage_error": mean_absolute_percentage_error,
126130
"explained_variance_score": explained_variance_score,
127131
"r2_score": partial(r2_score, multioutput="variance_weighted"),
132+
"root_mean_squared_error": root_mean_squared_error,
133+
"root_mean_squared_log_error": root_mean_squared_log_error,
128134
"mean_normal_deviance": partial(mean_tweedie_deviance, power=0),
129135
"mean_poisson_deviance": mean_poisson_deviance,
130136
"mean_gamma_deviance": mean_gamma_deviance,
@@ -458,7 +464,10 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
458464
"mean_absolute_error",
459465
"median_absolute_error",
460466
"mean_squared_error",
467+
"mean_squared_log_error",
461468
"r2_score",
469+
"root_mean_squared_error",
470+
"root_mean_squared_log_error",
462471
"explained_variance_score",
463472
"mean_absolute_percentage_error",
464473
"mean_pinball_loss",
@@ -482,6 +491,9 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
482491
"micro_f1_score",
483492
"macro_f1_score",
484493
"weighted_recall_score",
494+
"mean_squared_log_error",
495+
"root_mean_squared_error",
496+
"root_mean_squared_log_error",
485497
# P = R = F = accuracy in multiclass case
486498
"micro_f0.5_score",
487499
"micro_f1_score",
@@ -551,6 +563,12 @@ def precision_recall_curve_padded_thresholds(*args, **kwargs):
551563
"d2_tweedie_score",
552564
}
553565

566+
# Metrics involving y = log(1+x)
567+
METRICS_WITH_LOG1P_Y = {
568+
"mean_squared_log_error",
569+
"root_mean_squared_log_error",
570+
}
571+
554572

555573
def _require_positive_targets(y1, y2):
556574
"""Make targets strictly positive"""
@@ -560,6 +578,16 @@ def _require_positive_targets(y1, y2):
560578
return y1, y2
561579

562580

581+
def _require_log1p_targets(y1, y2):
582+
"""Make targets strictly larger than -1"""
583+
offset = abs(min(y1.min(), y2.min())) - 0.99
584+
y1 = y1.astype(float)
585+
y2 = y2.astype(float)
586+
y1 += offset
587+
y2 += offset
588+
return y1, y2
589+
590+
563591
def test_symmetry_consistency():
564592
# We shouldn't forget any metrics
565593
assert (
@@ -582,6 +610,9 @@ def test_symmetric_metric(name):
582610
if name in METRICS_REQUIRE_POSITIVE_Y:
583611
y_true, y_pred = _require_positive_targets(y_true, y_pred)
584612

613+
elif name in METRICS_WITH_LOG1P_Y:
614+
y_true, y_pred = _require_log1p_targets(y_true, y_pred)
615+
585616
y_true_bin = random_state.randint(0, 2, size=(20, 25))
586617
y_pred_bin = random_state.randint(0, 2, size=(20, 25))
587618

@@ -631,6 +662,8 @@ def test_sample_order_invariance(name):
631662

632663
if name in METRICS_REQUIRE_POSITIVE_Y:
633664
y_true, y_pred = _require_positive_targets(y_true, y_pred)
665+
elif name in METRICS_WITH_LOG1P_Y:
666+
y_true, y_pred = _require_log1p_targets(y_true, y_pred)
634667

635668
y_true_shuffle, y_pred_shuffle = shuffle(y_true, y_pred, random_state=0)
636669

@@ -698,6 +731,8 @@ def test_format_invariance_with_1d_vectors(name):
698731

699732
if name in METRICS_REQUIRE_POSITIVE_Y:
700733
y1, y2 = _require_positive_targets(y1, y2)
734+
elif name in METRICS_WITH_LOG1P_Y:
735+
y1, y2 = _require_log1p_targets(y1, y2)
701736

702737
y1_list = list(y1)
703738
y2_list = list(y2)
@@ -986,6 +1021,8 @@ def check_single_sample(name):
9861021
# assert that no exception is thrown
9871022
if name in METRICS_REQUIRE_POSITIVE_Y:
9881023
values = [1, 2]
1024+
elif name in METRICS_WITH_LOG1P_Y:
1025+
values = [-0.7, 1]
9891026
else:
9901027
values = [0, 1]
9911028
for i, j in product(values, repeat=2):
@@ -2017,6 +2054,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20172054
check_array_api_regression_metric,
20182055
check_array_api_regression_metric_multioutput,
20192056
],
2057+
mean_squared_log_error: [
2058+
check_array_api_regression_metric,
2059+
check_array_api_regression_metric_multioutput,
2060+
],
20202061
d2_tweedie_score: [
20212062
check_array_api_regression_metric,
20222063
],
@@ -2036,6 +2077,14 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
20362077
linear_kernel: [check_array_api_metric_pairwise],
20372078
polynomial_kernel: [check_array_api_metric_pairwise],
20382079
rbf_kernel: [check_array_api_metric_pairwise],
2080+
root_mean_squared_error: [
2081+
check_array_api_regression_metric,
2082+
check_array_api_regression_metric_multioutput,
2083+
],
2084+
root_mean_squared_log_error: [
2085+
check_array_api_regression_metric,
2086+
check_array_api_regression_metric_multioutput,
2087+
],
20392088
sigmoid_kernel: [check_array_api_metric_pairwise],
20402089
}
20412090

sklearn/metrics/tests/test_regression.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,29 +245,33 @@ def test_regression_metrics_at_limits():
245245
assert_almost_equal(s([1, 1], [1, 1]), 1.0)
246246
assert_almost_equal(s([1, 1], [1, 1], force_finite=False), np.nan)
247247
msg = (
248-
"Mean Squared Logarithmic Error cannot be used when targets "
249-
"contain negative values."
248+
"Mean Squared Logarithmic Error cannot be used when "
249+
"targets contain values less than or equal to -1."
250250
)
251251
with pytest.raises(ValueError, match=msg):
252252
mean_squared_log_error([-1.0], [-1.0])
253253
msg = (
254-
"Mean Squared Logarithmic Error cannot be used when targets "
255-
"contain negative values."
254+
"Mean Squared Logarithmic Error cannot be used when "
255+
"targets contain values less than or equal to -1."
256256
)
257257
with pytest.raises(ValueError, match=msg):
258258
mean_squared_log_error([1.0, 2.0, 3.0], [1.0, -2.0, 3.0])
259259
msg = (
260-
"Mean Squared Logarithmic Error cannot be used when targets "
261-
"contain negative values."
260+
"Mean Squared Logarithmic Error cannot be used when "
261+
"targets contain values less than or equal to -1."
262262
)
263263
with pytest.raises(ValueError, match=msg):
264264
mean_squared_log_error([1.0, -2.0, 3.0], [1.0, 2.0, 3.0])
265265
msg = (
266-
"Root Mean Squared Logarithmic Error cannot be used when targets "
267-
"contain negative values."
266+
"Mean Squared Logarithmic Error cannot be used when "
267+
"targets contain values less than or equal to -1."
268268
)
269269
with pytest.raises(ValueError, match=msg):
270270
root_mean_squared_log_error([1.0, -2.0, 3.0], [1.0, 2.0, 3.0])
271+
msg = (
272+
"Root Mean Squared Logarithmic Error cannot be used when "
273+
"targets contain values less than or equal to -1."
274+
)
271275

272276
# Tweedie deviance error
273277
power = -1.2

0 commit comments

Comments
 (0)