
ENH Array API support for confusion_matrix #30440


Open
wants to merge 14 commits into base: main

Conversation

@StefanieSenger (Contributor) commented Dec 9, 2024

Reference Issues/PRs

towards #26024

Edit: This PR is now superseded by #30562.

What does this implement/fix? Explain your changes.

This PR aims to add Array API support to confusion_matrix(). I have run the CUDA tests on Colab and they pass as well.

@OmarManzoor @ogrisel @lesteve: do you want to have a look?


github-actions bot commented Dec 9, 2024

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 78b9612. Link to the linter CI: here

Comment on lines 377 to 384
if need_index_conversion:
label_to_ind = {y: x for x, y in enumerate(labels)}
y_pred = np.array([label_to_ind.get(x, n_labels + 1) for x in y_pred])
y_true = np.array([label_to_ind.get(x, n_labels + 1) for x in y_true])
y_pred = xp.asarray(
[label_to_ind.get(x, n_labels + 1) for x in y_pred], device=device_
)
y_true = xp.asarray(
[label_to_ind.get(x, n_labels + 1) for x in y_true], device=device_
)
@StefanieSenger (Contributor, Author) commented Dec 9, 2024

This code block within the if need_index_conversion condition is only exercised with ndarrays, because of the way our tests are written. It should work for the other array libraries that we currently integrate, but I feel we should in fact test this part of the code?

@StefanieSenger (Contributor, Author) commented Dec 10, 2024

It fails in label_to_ind = {y: x for x, y in enumerate(labels)} for array_api_strict, because these array elements are not hashable.

I will try to fix this (possibly by refactoring, no spoilers) and add a test.

@StefanieSenger (Contributor, Author) commented Dec 12, 2024

I have asked here and they explained to me how to deal with it. It is fixed and I have added a test.

(I have also tried to refactor, but without success so far: what I tried is much slower than the status quo, so at least in this PR I won't try further, since we have a working solution.)

Member

Not necessarily for this PR, but maybe we could use unique_inverse.

Something like this might work:

_, y_pred = xp.unique_inverse(y_pred)

Maybe a clearer example with array-api-strict:

In [1]: import array_api_strict
   ...: arr = array_api_strict.asarray([1, 3, 4, 1, 3, 5])
   ...: unique_values, labels = array_api_strict.unique_inverse(arr)
   ...: labels
   ...: 
Out[1]: Array([0, 1, 2, 0, 1, 3], dtype=array_api_strict.int64)

Member

Actually, looking a bit further, it even looks like we already have sklearn.utils._encode._unique(return_inverse=True), which seems to have array API support ...

Contributor Author

Oh, that looks pretty good. I will try it tomorrow. :)

@StefanieSenger (Contributor, Author) commented Dec 13, 2024

Not for this PR, but for information, @lesteve:

I have tried to make it work with unique_inverse and a few other approaches, but I find it pretty complex to handle the lookup with only array operations, especially with strings allowed as dtypes. Maybe someone else wants to give it a try. I won't try any further, but I can share my experience if anyone is interested.

In this PR, we're fine with the status quo of using the mapping.
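
(For illustration of the kind of approach discussed above, here is a rough sketch that is not part of the PR. It only works under fairly restrictive assumptions: labels is sorted, has no duplicates, covers every value observed in y_true and y_pred, and the dtype is numeric; string dtypes, which confusion_matrix also accepts, would not fit this approach.)

import array_api_strict as xp

labels = xp.asarray([1, 3, 4, 5])        # assumed sorted and duplicate-free
y_true = xp.asarray([1, 3, 4, 1, 3, 5])
y_pred = xp.asarray([3, 3, 4, 1, 5, 5])

# unique_inverse over the concatenation maps every element to its position
# among the sorted unique values, which equals its index in labels under
# the assumptions above.
joined = xp.concat([labels, y_true, y_pred])
_, inverse = xp.unique_inverse(joined)
n = labels.shape[0]
y_true_idx = inverse[n : n + y_true.shape[0]]
y_pred_idx = inverse[n + y_true.shape[0] :]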

Comment on lines 408 to 409
for true, pred, weight in zip(y_true, y_pred, sample_weight):
cm[true, pred] += weight
@StefanieSenger (Contributor, Author) commented Dec 9, 2024

Is this performant enough? I think it could be, because we are mostly dealing with small matrices at this point. But finding another way might be better. I am not sure how to do this with the tools available in array_api_strict though.

Member

I have the feeling that this loop will kill any computational benefit of array API support. We might as well ensure that y_true and y_pred are numpy arrays using _convert_to_numpy and rely on the coo_matrix trick instead. This would keep the code simpler.

That being said, I think convenience array API support for classification metrics that rely on confusion matrix internally is useful as discussed in #30439 (comment).
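
For context, a minimal sketch of that suggestion, with toy inputs standing in for the label-index arrays of the surrounding code (_convert_to_numpy is the private helper in sklearn.utils._array_api):

import numpy as np
import torch
from scipy.sparse import coo_matrix

from sklearn.utils._array_api import _convert_to_numpy

xp = torch

# Toy inputs already mapped to label indices, as in the surrounding code.
y_true = xp.asarray([0, 1, 1, 2, 0])
y_pred = xp.asarray([0, 1, 2, 2, 1])
sample_weight = xp.asarray([1, 1, 1, 1, 1])
n_labels = 3

y_true_np = _convert_to_numpy(y_true, xp)
y_pred_np = _convert_to_numpy(y_pred, xp)
sample_weight_np = _convert_to_numpy(sample_weight, xp)

# The "coo_matrix trick": duplicate (true, pred) coordinates are summed,
# accumulating the per-sample weights into the confusion matrix without a
# Python loop.
cm = coo_matrix(
    (sample_weight_np, (y_true_np, y_pred_np)),
    shape=(n_labels, n_labels),
    dtype=np.int64,
).toarray()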

Contributor

I agree that python loops do not seem to go well with GPUs. However, there doesn't seem to be an alternative within the array API, because it doesn't support any sort of advanced indexing.

So we either have to use the loop, if we insist on following the array API, or we can simply use the original code by utilizing _convert_to_numpy as @ogrisel suggested.

@StefanieSenger (Contributor, Author) commented Dec 12, 2024

I ran this to check both options:

import timeit

code_to_test = """
import torch
from sklearn.metrics import confusion_matrix
from sklearn.base import config_context

xp=torch

y_true = xp.randint(0, 2, (2**23,), dtype=xp.int32)
y_pred = xp.randint(0, 2, (2**23,), dtype=xp.int32)
with config_context(array_api_dispatch=True):
    confusion_matrix(y_true,y_pred)
"""

execution_time = timeit.timeit(code_to_test, number=1)
print(execution_time)

The results are pretty clear; I am really surprised by the difference:

with python loop:
115.27976658102125

with _convert_to_numpy:
2.0403239529696293

The code for the python loop is the current state of this PR.
The code for the _convert_to_numpy version is in an extra branch: here

So, I think this speaks for the _convert_to_numpy option?

Edit: This was tested on cpu and I will next test it in Colab for data staying on cpu vs. gpu-cpu-gpu conversion.

@StefanieSenger (Contributor, Author) commented Dec 16, 2024

I have executed this in Colab, where I had to massively reduce the size of the arrays, because otherwise Colab would interrupt itself and not finish executing. This was now tested on size 2**10:

with python loop and device='cpu': 2.3 s
with python loop and device='cuda': 7.9 s
with _convert_to_numpy and device='cpu': 2.9 s
with _convert_to_numpy and device='cuda': 2.7 s

I believe these results are not interpretable, because:

  1. executing everything just once makes the results really dependent on chance (on the way we happen to create the random arrays), but it seems I cannot increase the number of repetitions in Colab because of its limitations (should I run 100 repetitions on smaller arrays then?)
  2. the large difference between both cpu solutions is not there anymore (though it seems improbable that such a large difference happened by chance, and I wonder what else could be the reason)
  3. it is so much slower compared to the 2**23 version on my own laptop, and I cannot see how this can be

I can see that I must be doing something wrong, but it seems I need some feedback to improve this test.

Contributor Author

I just repeated the experiment using a Scaleway VM (type RENDER-S). I was able to execute the 2**23 array size versions for the python loop and got this:

with python loop:
execution_time for device cpu: 175.60608643199998
execution_time for device cuda: 468.15440756399994

with convert_to_numpy
execution_time for device cpu: 2.8455443269999705
execution_time for device cuda: 0.3635161219999645

So it seems the python loop is really having a bad effect, especially on cuda.

But there is still that problem with the number=1 repetitions, which makes our results subject to chance. I will re-work the test and report the results.

Contributor Author

I have now re-defined the test code. This is the test I ran using a Scaleway VM (type RENDER-S):

import timeit

code_to_test = """
import torch
from sklearn.metrics import confusion_matrix
from sklearn.base import config_context

xp=torch

y_true = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cpu")
y_pred = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cpu")
with config_context(array_api_dispatch=True):
    confusion_matrix(y_true,y_pred)
"""

execution_time = timeit.timeit(code_to_test, number=50)
print(f'execution_time for device cpu: {execution_time}')

code_to_test = """
import torch
from sklearn.metrics import confusion_matrix
from sklearn.base import config_context

xp=torch

y_true = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cuda")
y_pred = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cuda")
with config_context(array_api_dispatch=True):
    confusion_matrix(y_true,y_pred)
"""

execution_time = timeit.timeit(code_to_test, number=50)
print(f'execution_time for device cuda: {execution_time}')

These are the results:

with python loop (current branch)

execution_time for device cpu: 69.83575903099972
execution_time for device cuda: 276.34000336899953

with convert_to_numpy (this branch)

execution_time for device cpu: 1.9843785450002542
execution_time for device cuda: 0.4149802029996863

This further confirms that we should use convert_to_numpy and then the COO sparse matrix instead of the python loop.

I am happy with this test now. Did I misunderstand anything or did I forget to consider anything?

Sorry for the verbosity of my posts; I am just discovering both of these topics (GPU arrays and performance testing), so my thoughts were not always straightforward.

Contributor

Yes, I think the experiments you performed confirm our initial hypothesis that using python loops will have a negative impact on the performance of any array API related code.

Also, it makes sense that CUDA is more affected, because using loops involves synchronization and communication between the GPU device and the CPU code.

Thank you for running the experiments.

Contributor Author

Thanks for your confirmation, @OmarManzoor. 😃
I will then get rid of the python loop and use convert_to_numpy in combination with the coo_matrix.

@StefanieSenger (Contributor, Author) commented Dec 18, 2024

And for completeness, since @lesteve hinted that I should split off the setup from the actually tested code (which is only the call to confusion_matrix), I also ran this:

import timeit

setup = """
import torch
from sklearn.metrics import confusion_matrix
from sklearn.base import config_context

xp=torch

y_true = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cpu")
y_pred = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cpu")
"""

code_to_test = """
with config_context(array_api_dispatch=True):
    confusion_matrix(y_true,y_pred)
"""

execution_time = timeit.timeit(stmt=code_to_test, setup=setup, number=50)
print(f'execution_time for device cpu: {execution_time}')

setup = """
import torch
from sklearn.metrics import confusion_matrix
from sklearn.base import config_context

xp=torch

y_true = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cuda")
y_pred = xp.randint(low=0, high=10, size=(int(1e5),), dtype=xp.int32, device="cuda")
"""

code_to_test = """
with config_context(array_api_dispatch=True):
    confusion_matrix(y_true,y_pred)
"""

execution_time = timeit.timeit(stmt=code_to_test, setup=setup, number=50)
print(f'execution_time for device cuda: {execution_time}')

And again the results:

with python loop:

execution_time for device cpu: 69.24619939000013
execution_time for device cuda: 276.69077064399994

with convert_to_numpy:

execution_time for device cpu: 0.28062641000002486
execution_time for device cuda: 0.3014569509999774

These results, again, confirm that the loop would have been a pretty bad choice. Goodbye, loop.

Edit: this test now also complies with this comment from another PR about not including data creation in the timed code.

@ogrisel added the CUDA CI label Dec 9, 2024
@github-actions bot removed the CUDA CI label Dec 9, 2024
@OmarManzoor (Contributor) left a comment

Thanks for the PR @StefanieSenger

@StefanieSenger (Contributor, Author) left a comment

Thank you for reviewing, @OmarManzoor.
I have implemented your suggestions. Would you mind having another look?

(Currently still working on it: I found some problem, so no rush.) Edit: it's resolved.

@lesteve (Member) commented Dec 19, 2024

My main comment is: why don't we convert to numpy at the top of the function and then use the same numpy code as before? We end up converting to a numpy array to create the sparse matrix anyway. I would expect that manipulating 1d arrays on GPU is not going to give a huge speed improvement compared to using numpy, but maybe that's a naive assumption.

It seems like this kind of stuff has been discussed before, for example #28626 (comment), #30439 (comment), and #30439 (comment). Unfortunately, I have to admit that I don't have a good overview of it and whether there was an agreed general direction ...

My intuition tells me that it would be easier in metrics to convert to numpy first and then do the computation in numpy. I am fine with that not being "real Array API support". As an end-user, I want to do a cross_val_score with the heavy computation using the GPU to get a speed improvement (I guess .fit and maybe .predict) and not have any error. If GPU 1d arrays are converted to numpy arrays to compute metrics and that is not the bottleneck, I don't care that much.

@OmarManzoor (Contributor)

My main comment is: why don't we convert to numpy at the top of the function and then use the same numpy code as before? We end up converting to a numpy array to create the sparse matrix anyway. I would expect that manipulating 1d arrays on GPU is not going to give a huge speed improvement compared to using numpy, but maybe that's a naive assumption.

I think that if the speed benefits gained from using the array API are not significant within the computation, then we can probably just use numpy from start to end. However, if there is a benefit to using the array API in places where it can be used, then I think it would be useful to use the array API and convert to numpy only where required.

@StefanieSenger (Contributor, Author) commented Dec 23, 2024

@lesteve and @OmarManzoor:

In a separate branch (not sure if you can see the diff), I have checked how it would be if we converted all the input into numpy during the param validation at the top of the function and then returned the confusion matrix in whatever namespace the input came from. These are the results:

execution_time for device cpu: 0.6608354229999804
execution_time for device cuda: 0.8810625369999343

And here, for comparison again the results from the current branch:

execution_time for device cpu: 2.0540889530000186
execution_time for device cuda: 0.6095863729999564

I ran the same test as here, but with larger arrays (size=(int(1e6),)).

It seems that the numpy version is faster for numpy inputs and there is not a big difference for cuda, so I would tend to prefer the new branch version. Would you agree or is there something else to consider?

@OmarManzoor (Contributor) commented Dec 24, 2024

It seems that the numpy version is faster for numpy inputs and there is not a big difference for cuda, so I would tend to prefer the new branch version. Would you agree or is there something else to consider?

Looking at the results I think we might just convert to numpy and use the original code instead of complicating the logic inside this function with the array api.

I also ran some benchmarks on a Kaggle notebook using arrays of size 1e8, and the results show a similar trend. Even though the CUDA version seems to be consistently somewhat better when using the array API, I don't think the performance gain is too dramatic.

RAM 29GB
GPU P100 16GB
array size 1e8

simple numpy by converting to numpy at the start

avg execution_time for device cpu: 7.236887979507446
avg execution_time for device cuda: 5.599829816818238

current code with array api

avg execution_time for device cpu: 11.78741683959961
avg execution_time for device cuda: 3.5200899124145506

Here is the short Python script that I used:

from time import time

import torch
from tqdm import tqdm

from sklearn.base import config_context
from sklearn.metrics import confusion_matrix

xp = torch

execution_times = []
for _ in tqdm(range(10), desc="CPU"):
    y_true = xp.randint(low=0, high=10, size=(int(1e8),), dtype=xp.int64, device="cpu")
    y_pred = xp.randint(low=0, high=10, size=(int(1e8),), dtype=xp.int64, device="cpu")
    start = time()
    with config_context(array_api_dispatch=True):
        confusion_matrix(y_true, y_pred)
    execution_times.append(time() - start)
print(f"avg execution_time for device cpu: {sum(execution_times) / 10}")

execution_times = []
for _ in tqdm(range(10), desc="CUDA"):
    y_true = xp.randint(low=0, high=10, size=(int(1e8),), dtype=xp.int64, device="cuda")
    y_pred = xp.randint(low=0, high=10, size=(int(1e8),), dtype=xp.int64, device="cuda")
    start = time()
    with config_context(array_api_dispatch=True):
        confusion_matrix(y_true, y_pred)
    execution_times.append(time() - start)
print(f"execution_time for device cuda: {sum(execution_times) / 10}")

@ogrisel (Member) left a comment

I had started a review of this branch before considering the fact that the new branch might be better, which would make some of the points moot.

The comments related to testing and documenting the output array namespace and device of the function are still valid for the other branch, though.

y_pred = xp.asarray([4, 5, 6], device=device)

with config_context(array_api_dispatch=True):
    confusion_matrix(y_true, y_pred)
Member

Maybe add a check about the result to assert that the namespace of the result array is array_namespace.

The device is not necessarily the same, though (because we do not move back to the input device). I think it's fine to keep the result array on a CPU device even if the input arrays are GPU-allocated.

I don't think there is a library-agnostic way to check that a device object is a CPU device:

https://data-apis.org/array-api/latest/design_topics/device_support.html#semantics

Maybe we could check the output as follows:

Suggested change
confusion_matrix(y_true, y_pred)
result = confusion_matrix(y_true, y_pred)
xp_result, device_result = get_namespace_and_device(result)
assert xp_result is xp
# Since the final computation always happens with NumPy / SciPy on
# the CPU, this function is expected to return an array allocated
# on the default device even when it does not match the input array's
# device.
default_device = device(xp.zeros(0))
assert device_result == default_device

If the last assertion does not work for any reason, I think it's fine not to test the result array device.

Member

The fact that we don't move the data back to the input arrays' device can be a bit surprising. Maybe we should document that somewhere, but I am not sure how.

@lucyleeow (Member) commented Feb 7, 2025

What about a table here-ish? We could also include info on whether support is 'surface' only and we actually do all the computation in numpy, like what we've decided for confusion_matrix. Though I appreciate this isn't binary: in some functions part of the compute will be array API compliant, or we convert to numpy only for compiled functions, etc.

Member

@lucyleeow you may know this already but my current understanding is that the confusion_matrix PR that is more likely to be merged is #30562.

I am a bit unsure in confusion_matrix about moving back to the original array namespace. I am slightly leaning towards doing it for consistency's sake, even if I am not entirely convinced it is that useful in practice.

Member

Thanks, don't worry, I saw that; I'm just continuing the discussion here since it was started here.

Comment on lines +1121 to +1126
if xp.isdtype(array.dtype, "real floating"):
array[xp.isinf(array) & (array > 0)] = xp.finfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.finfo(array.dtype).min
else: # xp.isdtype(array.dtype, "integral")
array[xp.isinf(array) & (array > 0)] = xp.iinfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.iinfo(array.dtype).min
Member

Since xp.isinf(array) is always called twice in a row, let's reuse the result of the first call.

Suggested change
if xp.isdtype(array.dtype, "real floating"):
array[xp.isinf(array) & (array > 0)] = xp.finfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.finfo(array.dtype).min
else: # xp.isdtype(array.dtype, "integral")
array[xp.isinf(array) & (array > 0)] = xp.iinfo(array.dtype).max
array[xp.isinf(array) & (array < 0)] = xp.iinfo(array.dtype).min
isinf_mask = xp.isinf(array)
if xp.isdtype(array.dtype, "real floating"):
array[isinf_mask & (array > 0)] = xp.finfo(array.dtype).max
array[isinf_mask & (array < 0)] = xp.finfo(array.dtype).min
else: # xp.isdtype(array.dtype, "integral")
array[isinf_mask & (array > 0)] = xp.iinfo(array.dtype).max
array[isinf_mask & (array < 0)] = xp.iinfo(array.dtype).min

minimum numbers available for the dtype respectively; like np.nan_to_num."""
xp, _ = get_namespace(array, xp=xp)
try:
    array = xp.nan_to_num(array)
Member

Suggested change
array = xp.nan_to_num(array)
# nan_to_num is not part of the array API spec at this time but is generally
# consistently well adopted, so we anticipate a future inclusion in the spec:
# https://github.com/data-apis/array-api/issues/878
array = xp.nan_to_num(array)

shape=(n_labels, n_labels),
dtype=dtype,
).toarray()
cm = xp.asarray(cm)
Member

We could move this conversion to the last line (just before the return statement) and this would side-step the problems related to np.errstate-style error control not being part of the array API spec.

Since we do not move the cm back to the input device, using the array namespace for the normalization step below is not expected to yield any computational advantages.

This would also avoid maintaining a fallback implementation of nan_to_num, since it is not (yet?) part of the spec.
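
A rough sketch of that reordering, written as a hypothetical helper (the normalize handling below is a simplified stand-in, not a copy of the actual implementation):

import numpy as np

def _normalize_then_convert(cm, normalize, xp):
    # cm is the NumPy array built via the coo_matrix trick; the whole
    # normalization stays in NumPy so np.errstate and np.nan_to_num can be
    # used as before.
    with np.errstate(all="ignore"):
        if normalize == "true":
            cm = cm / cm.sum(axis=1, keepdims=True)
        elif normalize == "pred":
            cm = cm / cm.sum(axis=0, keepdims=True)
        elif normalize == "all":
            cm = cm / cm.sum()
        cm = np.nan_to_num(cm)
    # Single conversion to the input namespace just before returning.
    return xp.asarray(cm)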

@ogrisel (Member) commented Dec 26, 2024

It seems that the numpy version is faster for numpy inputs and there is not a big difference for cuda, so I would tend to prefer the new branch version. Would you agree or is there something else to consider?

I agree, but I would not have expected the NumPy or torch CPU performance to be so degraded in this branch compared to the other branch. This is visible both in your int(1e6) data benchmark and in @OmarManzoor's larger data benchmark above (#30440 (comment)).

It might be interesting to profile both branches with a large enough dataset to understand where this discrepancy comes from.

@ogrisel (Member) commented Dec 26, 2024

Note that in the other branch, you move the results back to the input arrays' device:

https://github.com/StefanieSenger/scikit-learn/blob/1f23f634d55a9f9b7182584ddb6e6afd37548d98/sklearn/metrics/_classification.py#L424

I am not sure if this is needed / useful:

  • on the plus side, it is less surprising to the caller;
  • on the minus side, it might cause unnecessary data transfer / synchronization overhead between the host and the device: in general confusion matrices are tiny matrices (n_classes is seldom larger than 1000) so I don't expect downstream computation to ever benefit from GPU computation, even when the inputs y_true and y_pred are huge arrays.

Shall we even keep the result as a NumPy array?

@ogrisel (Member) commented Dec 26, 2024

I did some profiling using py-spy and here is what I found:

  • when using a very large number of data points (int(1e8)), computing unique input values in y_true and y_pred dominates (more than 40% of the total time), even for numpy inputs that benefit from the unique value cache (sklearn.utils._unique.attach_unique/cached_unique);
  • the numpy-dtype based unique cache system does not work when using pytorch tensors (as expected...) and then computing unique values takes 80% or more of the total time because it happens several times in _check_targets and then in unique_labels;
  • the measured times vary a lot from run to run, and I am not sure if there is a significant difference between the two branches for numpy inputs.

However, converting early to numpy should allow benefiting from the caching of unique values if we convert to numpy before calling attach_unique.

@ogrisel (Member) commented Dec 26, 2024

@StefanieSenger could you please open a PR from your other branch (and cross-link to here) and then move the calls to attach_unique after the numpy conversion at the beginning of:

https://github.com/StefanieSenger/scikit-learn/blob/1f23f634d55a9f9b7182584ddb6e6afd37548d98/sklearn/metrics/_classification.py#L340-L367

@OmarManzoor (Contributor)

  • on the plus side, it is less surprising to the caller;
  • on the minus side, it might cause unnecessary data transfer / synchronization overhead between the host and the device: in general confusion matrices are tiny matrices (n_classes is seldom larger than 1000) so I don't expect downstream computation to ever benefit from GPU computation, even when the inputs y_true and y_pred are huge arrays.

Shall we even keep the result as a NumPy array?

I think we can keep the result as a NumPy array as this returns a metric which I doubt would be utilized in further computations that require a GPU specifically.

@lucascolley (Contributor)

Shall we even keep the result as a NumPy array?

I think we can keep the result as a NumPy array as this returns a metric which I doubt would be utilized in further computations that require a GPU specifically.

This is interesting and quite surprising to me; have you included this pattern anywhere else in scikit-learn? In SciPy, we have been quite careful to ensure that everything is namespace_in === namespace_out. I think there is a concern that telling users "mostly namespace_in === namespace_out, except ..." will be quite confusing. What if a user wants to do this?

# arrays of some namespace xp
x1 = confusion_matrix(y_true,y_pred)
z = xp.concat((x1, x2))

Since xp.concat may only accept xp arrays and not numpy arrays, you'd have to make sure users know to call xp.asarray on the confusion matrix. I think it would be pretty confusing if you have to do this for some sklearn functions and not others.

@lucascolley (Contributor)

I think that adding an option to short-circuit and return the result as a NumPy array sounds like a great idea, but I don't think it should be the default behaviour.

@StefanieSenger (Contributor, Author) commented Jan 3, 2025

I think there is a concern that telling users "mostly namespace_in === namespace_out, except ..." will be quite confusing.

I agree. I also think that converting back to the namespace that was passed wouldn't be very costly here.

@ogrisel (Member) commented Jan 5, 2025

I think that adding an option to short-circuit and return the result as a NumPy array sounds like a great idea, but I don't think it should be the default behaviour.

So far, we have not had to change the public API of scikit-learn to add array API support (beyond enabling/disabling array API dispatch via sklearn.config_context or sklearn.set_config). Since confusion_matrix is a public function, adding such an option would set a precedent. I would therefore only add such an argument to a private _confusion_matrix function meant to be called either by the public confusion_matrix function or directly by public metric functions that know whether it's useful to keep the result as a numpy array.
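
Purely as an illustration of that split (the helper name and the return_numpy flag are hypothetical, not an existing scikit-learn API; for self-containment this sketch wraps the existing public function rather than containing the real computation):

import numpy as np

from sklearn.metrics import confusion_matrix as _public_confusion_matrix
from sklearn.utils._array_api import _convert_to_numpy, get_namespace


def _confusion_matrix(y_true, y_pred, *, sample_weight=None, normalize=None,
                      return_numpy=False):
    # Always compute with the existing NumPy/SciPy code ...
    xp, _ = get_namespace(y_true, y_pred)
    cm = _public_confusion_matrix(
        _convert_to_numpy(y_true, xp),
        _convert_to_numpy(y_pred, xp),
        sample_weight=sample_weight,
        normalize=normalize,
    )
    # ... and let private callers (other metrics) opt out of converting back.
    return cm if return_numpy else xp.asarray(cm)

The public confusion_matrix would then call this with return_numpy=False, while metrics that build on the confusion matrix internally could pass return_numpy=True and keep working in NumPy.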

@OmarManzoor (Contributor) commented Jan 6, 2025

Since confusion_matrix is a public function, adding such an option would set a precedent. I would therefore only add such an argument to a private _confusion_matrix function meant to be called either by the public confusion_matrix function or directly by public metric functions that know whether it's useful to keep the result as a numpy array.

In that case, wouldn't it just be better to cast to the intended namespace and device just before returning?

@lucyleeow added the Superseded label (PR has been replaced by a newer PR) Feb 6, 2025
Labels: Array API, module:metrics, Superseded (PR has been replaced by a newer PR)