Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/modules/array_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ Metrics
- :func:`sklearn.metrics.d2_tweedie_score`
- :func:`sklearn.metrics.explained_variance_score`
- :func:`sklearn.metrics.f1_score`
- :func:`sklearn.metrics.log_loss`
- :func:`sklearn.metrics.max_error`
- :func:`sklearn.metrics.mean_absolute_error`
- :func:`sklearn.metrics.mean_absolute_percentage_error`
Expand Down
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/array-api/30439.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- :func:`sklearn.metrics.log_loss` now supports Array API compatible inputs.
By :user:`Omar Salman <OmarManzoor>`
37 changes: 22 additions & 15 deletions sklearn/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@
column_or_1d,
)
from ..utils._array_api import (
_allclose,
_average,
_bincount,
_convert_to_numpy,
_count_nonzero,
_find_matching_floating_dtype,
_is_numpy_namespace,
Expand All @@ -39,6 +41,7 @@
device,
get_namespace,
get_namespace_and_device,
supported_float_dtypes,
)
from ..utils._param_validation import (
Hidden,
Expand Down Expand Up @@ -2953,17 +2956,16 @@ def log_loss(y_true, y_pred, *, normalize=True, sample_weight=None, labels=None)
... [[.1, .9], [.9, .1], [.8, .2], [.35, .65]])
0.21616...
"""
y_pred = check_array(
y_pred, ensure_2d=False, dtype=[np.float64, np.float32, np.float16]
)
xp, _, device_ = get_namespace_and_device(y_true, y_pred, sample_weight, labels)
y_pred = check_array(y_pred, ensure_2d=False, dtype=supported_float_dtypes(xp=xp))

check_consistent_length(y_pred, y_true, sample_weight)
lb = LabelBinarizer()

if labels is not None:
lb.fit(labels)
lb.fit(_convert_to_numpy(labels, xp=xp))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This somewhat misses the point of "supporting array API" here. I'd say we support the array API if we don't convert to NumPy, and here we do. So in effect, there's not much of an improvement with this PR.

I think in order to get this merged, LabelBinarizer should support array API.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that would be better, but I think we still perform computations after the LabelBinarizer part, particularly the sums, clipping and xlogy. Those might still bring some improvements, as SciPy's xlogy supports the array API.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might not be worth moving the data back and forth between devices

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I think you are right.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the description to reflect this.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that means we need to shelve this PR until we fix LabelBinarizer.

else:
lb.fit(y_true)
lb.fit(_convert_to_numpy(y_true, xp=xp))

if len(lb.classes_) == 1:
if labels is None:
Expand All @@ -2979,32 +2981,37 @@ def log_loss(y_true, y_pred, *, normalize=True, sample_weight=None, labels=None)
"got {0}.".format(lb.classes_)
)

transformed_labels = lb.transform(y_true)
float_dtype = _find_matching_floating_dtype(y_true, y_pred, xp=xp)
transformed_labels = xp.asarray(
lb.transform(_convert_to_numpy(y_true, xp=xp)),
dtype=float_dtype,
device=device_,
)

if transformed_labels.shape[1] == 1:
transformed_labels = np.append(
1 - transformed_labels, transformed_labels, axis=1
transformed_labels = xp.concat(
(1 - transformed_labels, transformed_labels), axis=1
)

# If y_pred is of single dimension, assume y_true to be binary
# and then check.
if y_pred.ndim == 1:
y_pred = y_pred[:, np.newaxis]
y_pred = y_pred[:, xp.newaxis]
if y_pred.shape[1] == 1:
y_pred = np.append(1 - y_pred, y_pred, axis=1)
y_pred = xp.concat((1 - y_pred, y_pred), axis=1)

eps = np.finfo(y_pred.dtype).eps
eps = xp.finfo(y_pred.dtype).eps

# Make sure y_pred is normalized
y_pred_sum = y_pred.sum(axis=1)
if not np.allclose(y_pred_sum, 1, rtol=np.sqrt(eps)):
y_pred_sum = xp.sum(y_pred, axis=1)
if not _allclose(y_pred_sum, 1, rtol=np.sqrt(eps), xp=xp):
Copy link
Contributor Author

@OmarManzoor OmarManzoor Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Since the `_allclose` helper internally converts to NumPy in all cases, we can use np.sqrt instead of xp.sqrt here. eps is just a scalar value, and using xp.sqrt would require us to unnecessarily convert eps into an array to satisfy array-api-strict.

warnings.warn(
"The y_pred values do not sum to one. Make sure to pass probabilities.",
UserWarning,
)

# Clipping
y_pred = np.clip(y_pred, eps, 1 - eps)
y_pred = xp.clip(y_pred, eps, 1 - eps)

# Check if dimensions are consistent.
transformed_labels = check_array(transformed_labels)
Expand All @@ -3026,7 +3033,7 @@ def log_loss(y_true, y_pred, *, normalize=True, sample_weight=None, labels=None)
"labels: {0}".format(lb.classes_)
)

loss = -xlogy(transformed_labels, y_pred).sum(axis=1)
loss = xp.sum(-xlogy(transformed_labels, y_pred), axis=1)

return float(_average(loss, weights=sample_weight, normalize=normalize))

Expand Down
24 changes: 22 additions & 2 deletions sklearn/metrics/tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,7 +1831,11 @@ def check_array_api_binary_classification_metric(
metric, array_namespace, device, dtype_name
):
y_true_np = np.array([0, 0, 1, 1])
y_pred_np = np.array([0, 1, 0, 1])
# Log loss requires probabilities instead of raw labels.
if metric == log_loss:
y_pred_np = np.array([0.5, 0.2, 0.6, 0.7], dtype=dtype_name)
else:
y_pred_np = np.array([0, 1, 0, 1])

check_array_api_metric(
metric,
Expand Down Expand Up @@ -1860,7 +1864,19 @@ def check_array_api_multiclass_classification_metric(
metric, array_namespace, device, dtype_name
):
y_true_np = np.array([0, 1, 2, 3])
y_pred_np = np.array([0, 1, 0, 2])
# Log loss requires probabilities instead of raw labels.
if metric == log_loss:
y_pred_np = np.array(
[
[0.5, 0.2, 0.2, 0.1],
[0.4, 0.4, 0.1, 0.1],
[0.1, 0.1, 0.7, 0.1],
[0.1, 0.2, 0.6, 0.1],
],
dtype=dtype_name,
)
else:
y_pred_np = np.array([0, 1, 0, 2])

additional_params = {
"average": ("micro", "macro", "weighted"),
Expand Down Expand Up @@ -2066,6 +2082,10 @@ def check_array_api_metric_pairwise(metric, array_namespace, device, dtype_name)
check_array_api_multiclass_classification_metric,
check_array_api_multilabel_classification_metric,
],
log_loss: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
],
multilabel_confusion_matrix: [
check_array_api_binary_classification_metric,
check_array_api_multiclass_classification_metric,
Expand Down
13 changes: 13 additions & 0 deletions sklearn/utils/_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,3 +1101,16 @@ def _tolist(array, xp=None):
return array.tolist()
array_np = _convert_to_numpy(array, xp=xp)
return [element.item() for element in array_np]


def _allclose(a, b, rtol=1e-5, atol=1e-8, xp=None):
    """Array API friendly wrapper around :func:`numpy.allclose`.

    Both operands are brought into NumPy before the comparison, so the
    check itself always runs through ``numpy.allclose``.

    Parameters
    ----------
    a : array
        First operand; always converted with ``_convert_to_numpy``.
    b : array or scalar
        Second operand. Python ``int``, ``float`` and ``complex`` values are
        forwarded unchanged; anything else is converted with
        ``_convert_to_numpy``.
    rtol : float, default=1e-5
        Relative tolerance forwarded to ``numpy.allclose``.
    atol : float, default=1e-8
        Absolute tolerance forwarded to ``numpy.allclose``.
    xp : module, default=None
        Array namespace; passed to ``get_namespace`` together with the
        inputs to resolve the namespace used for the conversions.

    Returns
    -------
    The result of ``numpy.allclose`` on the converted operands.
    """
    # Python scalars are excluded from namespace inspection.
    xp, _ = get_namespace(a, b, remove_types=(float, int, complex), xp=xp)

    lhs = _convert_to_numpy(a, xp=xp)
    if isinstance(b, (int, float, complex)):
        # Plain scalars need no conversion; numpy.allclose accepts them as is.
        rhs = b
    else:
        rhs = _convert_to_numpy(b, xp=xp)

    return numpy.allclose(lhs, rhs, rtol=rtol, atol=atol)
Loading