FIX Apply dtype param in check_array_api_compute_metric unit test #27940

fcharras
Contributor

Came across this unfortunate mistake in a test for Array API dispatch. The test is parametrized and scheduled once per dtype, but every case ignores the dtype parameter, so they all run the same thing.

Fixed by actually parsing the dtype parameter and applying it.
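
For readers unfamiliar with this kind of bug, here is a minimal sketch of the pattern described above. It is not the actual scikit-learn test code: the helper names `make_input_ignoring_dtype` and `make_input_applying_dtype` are hypothetical, and plain NumPy with a simple loop stands in for the parametrized Array API test.

```python
import numpy as np


def make_input_ignoring_dtype(dtype_name):
    # Bug pattern (hypothetical helper): dtype_name is accepted but never
    # used, so every case gets NumPy's default float64 array.
    return np.asarray([0.0, 1.0, 1.0])


def make_input_applying_dtype(dtype_name):
    # Fix pattern (hypothetical helper): resolve the string to a dtype
    # and actually apply it when building the array.
    return np.asarray([0.0, 1.0, 1.0], dtype=getattr(np, dtype_name))


for dtype_name in ("float64", "float32"):
    ignored = make_input_ignoring_dtype(dtype_name)
    applied = make_input_applying_dtype(dtype_name)
    # Show which dtype each variant really ends up with.
    print(dtype_name, ignored.dtype, applied.dtype)
```

With the first helper, the "float32" case silently exercises float64 data, which is why all the scheduled cases were effectively identical.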

github-actions bot commented Dec 11, 2023

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 27bc656. Link to the linter CI: here

Member

@ogrisel ogrisel left a comment

Thanks for the PR. I think the code would be more maintainable, and less likely to silently do something other than what we expect, if we adopted the following naming conventions for arguments and local variables in array API tests and testing utilities, which typically mix np and xp function/method calls:

  • dtype_name: a string such as "float32", "uint8"...
  • xp_dtype: a dtype attribute of the xp module: xp_dtype = getattr(xp, dtype_name),
  • np_dtype: a numpy dtype object: np_dtype = getattr(np, dtype_name).

For regular library code (outside of the tests), where there is no ambiguity, we can keep using the dtype variable name instead of the latter two.
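
The naming convention above could look roughly like this in a test utility. This is only a sketch under assumptions: `resolve_dtypes` is a hypothetical helper, and NumPy itself is used as a stand-in for the `xp` namespace so that the snippet runs without torch or another Array API library installed.

```python
import numpy as np


def resolve_dtypes(xp, dtype_name):
    """Return (xp_dtype, np_dtype) for a dtype name such as "float32".

    Hypothetical helper illustrating the proposed variable names.
    """
    xp_dtype = getattr(xp, dtype_name)  # dtype attribute of the xp module
    np_dtype = getattr(np, dtype_name)  # the NumPy counterpart
    return xp_dtype, np_dtype


xp = np  # stand-in namespace; in the real tests xp could be torch, etc.
dtype_name = "float32"
xp_dtype, np_dtype = resolve_dtypes(xp, dtype_name)

# Build the reference data with NumPy, then convert it to the xp namespace.
X_np = np.asarray([[1, 2], [3, 4]], dtype=np_dtype)
X_xp = xp.asarray(X_np, dtype=xp_dtype)
```

Keeping the three names distinct makes it harder to pass a string where a dtype object is expected, or to mix an xp dtype into a NumPy call.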

@fcharras fcharras force-pushed the FIX/dtype_not_applied_in_array_api_metrics_test branch 2 times, most recently from c2779fc to 60c0aee Compare December 12, 2023 14:25
@fcharras fcharras force-pushed the FIX/dtype_not_applied_in_array_api_metrics_test branch from d2d356b to 14a2b5c Compare December 12, 2023 16:17
@fcharras
Contributor Author

fcharras commented Dec 13, 2023

I renamed dtype to dtype_name across the Array API tests; I hope I covered it all. Tests pass with torch and cuda enabled:

Test logs:

pytest -v sklearn -k torch

collected 33000 items / 32828 deselected / 2 skipped / 172 selected                                                                                                                                                                        

sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                              [  0%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                              [  1%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                             [  1%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                             [  2%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FA...) [  2%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cpu-float64] PASSED                                                                 [  3%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cpu-float32] PASSED                                                                 [  4%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cuda-float64] PASSED                                                                [  4%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-cuda-float32] PASSED                                                                [  5%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,svd_solver='full')-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLB...) [  5%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                [  6%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                [  6%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cuda-float64] PASSED                                               [  7%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-cuda-float32] PASSED                                               [  8%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH...) [  8%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cpu-float64] PASSED                                                   [  9%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cpu-float32] PASSED                                                   [  9%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cuda-float64] PASSED                                                  [ 10%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-cuda-float32] PASSED                                                  [ 11%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=0.1,svd_solver='full',whiten=True)-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_EN...) [ 11%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cpu-float64] PASSED         [ 12%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cpu-float32] PASSED         [ 12%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cuda-float64] PASSED        [ 13%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-cuda-float32] PASSED        [ 13%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_input_and_values-torch-mps-float32] SKIPPED (S...) [ 14%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cpu-float64] PASSED            [ 15%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cpu-float32] PASSED            [ 15%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cuda-float64] PASSED           [ 16%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-cuda-float32] PASSED           [ 16%]
sklearn/decomposition/tests/test_pca.py::test_pca_array_api_compliance[PCA(n_components=2,power_iteration_normalizer='QR',random_state=0,svd_solver='randomized')-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skip...) [ 17%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cpu-float64] PASSED                                                                 [ 18%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cpu-float32] PASSED                                                                 [ 18%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cuda-float64] PASSED                                                                [ 19%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cuda-float32] PASSED                                                                [ 19%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLB...) [ 20%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cpu-float64] PASSED                                                         [ 20%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cpu-float32] PASSED                                                         [ 21%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cuda-float64] PASSED                                                        [ 22%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-cuda-float32] PASSED                                                        [ 22%]
sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_get_precision-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_M...) [ 23%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-cpu-float64] PASSED                                                                                [ 23%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-cpu-float32] PASSED                                                                                [ 24%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-cuda-float64] PASSED                                                                               [ 25%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-cuda-float32] PASSED                                                                               [ 25%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)    [ 26%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-cpu-float64] PASSED                                                                            [ 26%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-cpu-float32] PASSED                                                                            [ 27%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-cuda-float64] PASSED                                                                           [ 27%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-cuda-float32] PASSED                                                                           [ 28%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not ...) [ 29%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-cpu-float64] PASSED                                                                                 [ 29%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-cpu-float32] PASSED                                                                                 [ 30%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-cuda-float64] PASSED                                                                                [ 30%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-cuda-float32] PASSED                                                                                [ 31%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)     [ 31%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-cpu-float64] PASSED                                                                             [ 32%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-cpu-float32] PASSED                                                                             [ 33%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-cuda-float64] PASSED                                                                            [ 33%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-cuda-float32] PASSED                                                                            [ 34%]
sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.) [ 34%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cpu-float64] PASSED                                                                                                                     [ 35%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cpu-float32] PASSED                                                                                                                     [ 36%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cuda-float64] PASSED                                                                                                                    [ 36%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-cuda-float32] PASSED                                                                                                                    [ 37%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-None-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                         [ 37%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cpu-float64] PASSED                                                                                                                [ 38%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cpu-float32] PASSED                                                                                                                [ 38%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cuda-float64] PASSED                                                                                                               [ 39%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-cuda-float32] PASSED                                                                                                               [ 40%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[True-stratify1-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                    [ 40%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cpu-float64] PASSED                                                                                                                    [ 41%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cpu-float32] PASSED                                                                                                                    [ 41%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cuda-float64] PASSED                                                                                                                   [ 42%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-cuda-float32] PASSED                                                                                                                   [ 43%]
sklearn/model_selection/tests/test_split.py::test_array_api_train_test_split[False-None-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                        [ 43%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MaxAbsScaler()-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                                                 [ 44%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MaxAbsScaler()-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                                                 [ 44%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MaxAbsScaler()-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                                                [ 45%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MaxAbsScaler()-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                                                [ 45%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MaxAbsScaler()-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)     [ 46%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MinMaxScaler()-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                                                 [ 47%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MinMaxScaler()-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                                                 [ 47%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MinMaxScaler()-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                                                [ 48%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MinMaxScaler()-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                                                [ 48%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[MinMaxScaler()-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)     [ 49%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[KernelCenterer()-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                                               [ 50%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[KernelCenterer()-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                                               [ 50%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[KernelCenterer()-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                                              [ 51%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[KernelCenterer()-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                                              [ 51%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[KernelCenterer()-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)   [ 52%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='l1')-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                                          [ 52%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='l1')-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                                          [ 53%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='l1')-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                                         [ 54%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='l1')-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                                         [ 54%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='l1')-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is no...) [ 55%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer()-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                                                   [ 55%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer()-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                                                   [ 56%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer()-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                                                  [ 56%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer()-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                                                  [ 57%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer()-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)       [ 58%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='max')-check_array_api_input_and_values-torch-cpu-float64] PASSED                                                                         [ 58%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='max')-check_array_api_input_and_values-torch-cpu-float32] PASSED                                                                         [ 59%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='max')-check_array_api_input_and_values-torch-cuda-float64] PASSED                                                                        [ 59%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='max')-check_array_api_input_and_values-torch-cuda-float32] PASSED                                                                        [ 60%]
sklearn/preprocessing/tests/test_data.py::test_scaler_array_api_compliance[Normalizer(norm='max')-check_array_api_input_and_values-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is n...) [ 61%]
sklearn/tests/test_common.py::test_estimators[KernelCenterer()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cpu)] PASSED                                                                                    [ 61%]
sklearn/tests/test_common.py::test_estimators[KernelCenterer()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cpu)] PASSED                                                                                    [ 62%]
sklearn/tests/test_common.py::test_estimators[KernelCenterer()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cuda)] PASSED                                                                                   [ 62%]
sklearn/tests/test_common.py::test_estimators[KernelCenterer()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cuda)] PASSED                                                                                   [ 63%]
sklearn/tests/test_common.py::test_estimators[KernelCenterer()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=mps)] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)        [ 63%]
sklearn/tests/test_common.py::test_estimators[LinearDiscriminantAnalysis()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cpu)] PASSED                                                                        [ 64%]
sklearn/tests/test_common.py::test_estimators[LinearDiscriminantAnalysis()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cpu)] PASSED                                                                        [ 65%]
sklearn/tests/test_common.py::test_estimators[LinearDiscriminantAnalysis()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cuda)] PASSED                                                                       [ 65%]
sklearn/tests/test_common.py::test_estimators[LinearDiscriminantAnalysis()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cuda)] PASSED                                                                       [ 66%]
sklearn/tests/test_common.py::test_estimators[LinearDiscriminantAnalysis()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=mps)] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is ...) [ 66%]
sklearn/tests/test_common.py::test_estimators[Normalizer()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cpu)] PASSED                                                                                        [ 67%]
sklearn/tests/test_common.py::test_estimators[Normalizer()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cpu)] PASSED                                                                                        [ 68%]
sklearn/tests/test_common.py::test_estimators[Normalizer()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cuda)] PASSED                                                                                       [ 68%]
sklearn/tests/test_common.py::test_estimators[Normalizer()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cuda)] PASSED                                                                                       [ 69%]
sklearn/tests/test_common.py::test_estimators[Normalizer()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=mps)] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)            [ 69%]
sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cpu)] PASSED                                                                                               [ 70%]
sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cpu)] PASSED                                                                                               [ 70%]
sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype_name=float64,device=cuda)] PASSED                                                                                              [ 71%]
sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=cuda)] PASSED                                                                                              [ 72%]
sklearn/tests/test_common.py::test_estimators[PCA()-check_array_api_input(array_namespace=torch,dtype_name=float32,device=mps)] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                   [ 72%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-False-10.0-torch-cpu-float64] PASSED                                                                                                                                   [ 73%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-False-10.0-torch-cpu-float32] PASSED                                                                                                                                   [ 73%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-False-10.0-torch-cuda-float64] PASSED                                                                                                                                  [ 74%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-False-10.0-torch-cuda-float32] PASSED                                                                                                                                  [ 75%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-False-10.0-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                                       [ 75%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-True-2.5-torch-cpu-float64] PASSED                                                                                                                                     [ 76%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-True-2.5-torch-cpu-float32] PASSED                                                                                                                                     [ 76%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-True-2.5-torch-cuda-float64] PASSED                                                                                                                                    [ 77%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-True-2.5-torch-cuda-float32] PASSED                                                                                                                                    [ 77%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[None-True-2.5-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                                         [ 78%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight2-False-5.5-torch-cpu-float64] PASSED                                                                                                                          [ 79%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight2-False-5.5-torch-cpu-float32] PASSED                                                                                                                          [ 79%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight2-False-5.5-torch-cuda-float64] PASSED                                                                                                                         [ 80%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight2-False-5.5-torch-cuda-float32] PASSED                                                                                                                         [ 80%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight2-False-5.5-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                              [ 81%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight3-True-2.75-torch-cpu-float64] PASSED                                                                                                                          [ 81%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight3-True-2.75-torch-cpu-float32] PASSED                                                                                                                          [ 82%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight3-True-2.75-torch-cuda-float64] PASSED                                                                                                                         [ 83%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight3-True-2.75-torch-cuda-float32] PASSED                                                                                                                         [ 83%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight3-True-2.75-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                              [ 84%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight4-False-30.0-torch-cpu-float64] PASSED                                                                                                                         [ 84%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight4-False-30.0-torch-cpu-float32] PASSED                                                                                                                         [ 85%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight4-False-30.0-torch-cuda-float64] PASSED                                                                                                                        [ 86%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight4-False-30.0-torch-cuda-float32] PASSED                                                                                                                        [ 86%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight4-False-30.0-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                             [ 87%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight5-True-3.0-torch-cpu-float64] PASSED                                                                                                                           [ 87%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight5-True-3.0-torch-cpu-float32] PASSED                                                                                                                           [ 88%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight5-True-3.0-torch-cuda-float64] PASSED                                                                                                                          [ 88%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight5-True-3.0-torch-cuda-float32] PASSED                                                                                                                          [ 89%]
sklearn/utils/tests/test_array_api.py::test_weighted_sum[sample_weight5-True-3.0-torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                               [ 90%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X0-_nanmin-1-torch] PASSED                                                                                                                                                [ 90%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X1-_nanmin--2-torch] PASSED                                                                                                                                               [ 91%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X2-_nanmin-inf-torch] PASSED                                                                                                                                              [ 91%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X3-reduction3-expected3-torch] PASSED                                                                                                                                     [ 92%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X4-reduction4-expected4-torch] PASSED                                                                                                                                     [ 93%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X5-_nanmax-2-torch] PASSED                                                                                                                                                [ 93%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X6-_nanmax-2-torch] PASSED                                                                                                                                                [ 94%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X7-_nanmax--inf-torch] PASSED                                                                                                                                             [ 94%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X8-reduction8-expected8-torch] PASSED                                                                                                                                     [ 95%]
sklearn/utils/tests/test_array_api.py::test_nan_reductions[X9-reduction9-expected9-torch] PASSED                                                                                                                                     [ 95%]
sklearn/utils/tests/test_array_api.py::test_convert_to_numpy_gpu[torch] PASSED                                                                                                                                                       [ 96%]
sklearn/utils/tests/test_array_api.py::test_convert_estimator_to_ndarray[torch-<lambda>] PASSED                                                                                                                                      [ 97%]
sklearn/utils/tests/test_multiclass.py::test_is_multilabel_array_api_compliance[torch-cpu-float64] PASSED                                                                                                                            [ 97%]
sklearn/utils/tests/test_multiclass.py::test_is_multilabel_array_api_compliance[torch-cpu-float32] PASSED                                                                                                                            [ 98%]
sklearn/utils/tests/test_multiclass.py::test_is_multilabel_array_api_compliance[torch-cuda-float64] PASSED                                                                                                                           [ 98%]
sklearn/utils/tests/test_multiclass.py::test_is_multilabel_array_api_compliance[torch-cuda-float32] PASSED                                                                                                                           [ 99%]
sklearn/utils/tests/test_multiclass.py::test_is_multilabel_array_api_compliance[torch-mps-float32] SKIPPED (Skipping MPS device test because PYTORCH_ENABLE_MPS_FALLBACK is not set.)                                                [100%]

===================================================================================== 140 passed, 34 skipped, 32828 deselected, 76 warnings in 13.18s ======================================================================================

I haven't found places where xp_dtype and np_dtype would be necessary. Also, in the sklearn.utils._array_api module there is less risk of confusion, because numpy is imported as numpy rather than np, so it cannot be mistaken for xp.

@fcharras fcharras requested a review from ogrisel December 13, 2023 09:39
Member

@ogrisel ogrisel left a comment

Once this https://github.com/scikit-learn/scikit-learn/pull/27940/files#r1425130338 is addressed, LGTM even if codecov complains.

@fcharras fcharras force-pushed the FIX/dtype_not_applied_in_array_api_metrics_test branch from f1dcfbf to efbb30a Compare December 14, 2023 12:46
y_pred_np,
sample_weight=None,
)
sample_weight = xp.asarray(sample_weight, device=device)
Member

As I said IRL, I would rather avoid such clever recursive calls, which hide the intention of the original code when reading a traceback in the CI logs in case of test failures, and instead use a flatter, more direct/explicit style in the caller (even if slightly more redundant).
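
To illustrate the flatter style suggested above, here is a minimal sketch. It is not the code from this PR: `accuracy`, `_as_xp`, `check_metric` and `check_binary_metric` are hypothetical names, and plain tuples stand in for arrays converted with `xp.asarray`, so the snippet runs without any array library. The caller checks the unweighted and weighted cases with two explicit calls rather than having the checker call itself with `sample_weight=None`.

```python
def accuracy(y_true, y_pred, sample_weight=None):
    """Toy weighted accuracy, used only for this illustration."""
    if sample_weight is None:
        sample_weight = [1.0] * len(y_true)
    hits = [float(t == p) * w for t, p, w in zip(y_true, y_pred, sample_weight)]
    return sum(hits) / sum(sample_weight)


def _as_xp(values):
    # Stand-in for xp.asarray(values, device=device): tuples play the role of
    # the converted arrays so that the sketch has no third-party dependency.
    return None if values is None else tuple(values)


def check_metric(metric, y_true, y_pred, sample_weight):
    """Hypothetical checker: one call checks exactly one configuration."""
    expected = metric(y_true, y_pred, sample_weight=sample_weight)
    got = metric(_as_xp(y_true), _as_xp(y_pred), sample_weight=_as_xp(sample_weight))
    assert got == expected, f"expected {expected}, got {got}"
    return got


def check_binary_metric(metric):
    y_true = [0, 0, 1, 1]
    y_pred = [0, 1, 0, 1]
    # Two flat, explicit calls instead of check_metric calling itself:
    # a failing traceback then points directly at the case that broke.
    unweighted = check_metric(metric, y_true, y_pred, sample_weight=None)
    weighted = check_metric(
        metric, y_true, y_pred, sample_weight=[0.0, 0.1, 2.0, 1.0]
    )
    return unweighted, weighted
```

The cost is one extra call site in the caller; in exchange, each frame in a CI traceback corresponds to a single, named configuration.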

Member

@ogrisel ogrisel left a comment

Other than the above comment, LGTM! Thanks @fcharras.

@glemaitre glemaitre self-requested a review December 18, 2023 15:04
Member

@glemaitre glemaitre left a comment

It looks good but I have 2 types of failure locally with MPS support:

================================================================= FAILURES ==================================================================
___________ test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cpu-float32] ____________

estimator = PCA(n_components='mle', svd_solver='full'), check = <function check_array_api_input at 0x133ab89d0>, array_namespace = 'torch'
device = 'cpu', dtype_name = 'float32'

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize(
        "check",
        [check_array_api_input, check_array_api_get_precision],
        ids=_get_check_estimator_ids,
    )
    @pytest.mark.parametrize(
        "estimator",
        [
            # PCA with mle cannot use check_array_api_input_and_values because of
            # rounding errors in the noisy (low variance) components.
            PCA(n_components="mle", svd_solver="full"),
        ],
        ids=_get_check_estimator_ids,
    )
    def test_pca_mle_array_api_compliance(
        estimator, check, array_namespace, device, dtype_name
    ):
        name = estimator.__class__.__name__
>       check(name, estimator, array_namespace, device=device, dtype_name=dtype_name)

sklearn/decomposition/tests/test_pca.py:901: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

name = 'PCA', estimator_orig = PCA(n_components='mle', svd_solver='full'), array_namespace = 'torch', device = 'cpu', dtype_name = 'float32'
check_values = False

    def check_array_api_input(
        name,
        estimator_orig,
        array_namespace,
        device=None,
        dtype_name="float64",
        check_values=False,
    ):
        """Check that the estimator can work consistently with the Array API
    
        By default, this just checks that the types and shapes of the arrays are
        consistent with calling the same estimator with numpy arrays.
    
        When check_values is True, it also checks that calling the estimator on the
        array_api Array gives the same results as ndarrays.
        """
        xp = _array_api_for_tests(array_namespace, device)
    
        X, y = make_classification(random_state=42)
        X = X.astype(dtype_name, copy=False)
    
        X = _enforce_estimator_tags_X(estimator_orig, X)
        y = _enforce_estimator_tags_y(estimator_orig, y)
    
        est = clone(estimator_orig)
    
        X_xp = xp.asarray(X, device=device)
        y_xp = xp.asarray(y, device=device)
    
        est.fit(X, y)
    
        array_attributes = {
            key: value for key, value in vars(est).items() if isinstance(value, np.ndarray)
        }
    
        est_xp = clone(est)
        with config_context(array_api_dispatch=True):
            est_xp.fit(X_xp, y_xp)
            input_ns = get_namespace(X_xp)[0].__name__
    
        # Fitted attributes which are arrays must have the same
        # namespace as the one of the training data.
        for key, attribute in array_attributes.items():
            est_xp_param = getattr(est_xp, key)
            with config_context(array_api_dispatch=True):
                attribute_ns = get_namespace(est_xp_param)[0].__name__
            assert attribute_ns == input_ns, (
                f"'{key}' attribute is in wrong namespace, expected {input_ns} "
                f"got {attribute_ns}"
            )
    
            assert array_device(est_xp_param) == array_device(X_xp)
    
            est_xp_param_np = _convert_to_numpy(est_xp_param, xp=xp)
            if check_values:
                assert_allclose(
                    attribute,
                    est_xp_param_np,
                    err_msg=f"{key} not the same",
                    atol=np.finfo(X.dtype).eps * 100,
                )
            else:
>               assert attribute.shape == est_xp_param_np.shape
E               AssertionError

sklearn/utils/estimator_checks.py:928: AssertionError
_________________ test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-mps-float32] __________________

metric = <function accuracy_score at 0x130ba4040>, array_namespace = 'torch', device = 'mps', dtype_name = 'float32'
check_func = <function check_array_api_binary_classification_metric at 0x14d50c670>

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
    def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
>       check_func(metric, array_namespace, device, dtype_name)

sklearn/metrics/tests/test_common.py:1839: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sklearn/metrics/tests/test_common.py:1776: in check_array_api_binary_classification_metric
    check_array_api_metric(
sklearn/metrics/tests/test_common.py:1749: in check_array_api_metric
    metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
sklearn/utils/_param_validation.py:213: in wrapper
    return func(*args, **kwargs)
sklearn/metrics/_classification.py:221: in accuracy_score
    return _weighted_sum(score, sample_weight, normalize)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

sample_score = tensor([1., 0., 0., 1.], dtype=torch.float64), sample_weight = tensor([0.0000, 0.1000, 2.0000, 1.0000], device='mps:0')
normalize = True
xp = <module 'array_api_compat.torch' from '/Users/glemaitre/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/array_api_compat/torch/__init__.py'>

    def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
        # XXX: this function accepts Array API input but returns a Python scalar
        # float. The call to float() is convenient because it removes the need to
        # move back results from device to host memory (e.g. calling `.cpu()` on a
        # torch tensor). However, this might interact in unexpected ways (break?)
        # with lazy Array API implementations. See:
        # https://github.com/data-apis/array-api/issues/642
        if xp is None:
            xp, _ = get_namespace(sample_score)
        if normalize and _is_numpy_namespace(xp):
            sample_score_np = numpy.asarray(sample_score)
            if sample_weight is not None:
                sample_weight_np = numpy.asarray(sample_weight)
            else:
                sample_weight_np = None
            return float(numpy.average(sample_score_np, weights=sample_weight_np))
    
        if not xp.isdtype(sample_score.dtype, "real floating"):
            # We move to cpu device ahead of time since certain devices may not support
            # float64, but we want the same precision for all devices and namespaces.
            sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
    
        if sample_weight is not None:
>           sample_weight = xp.asarray(
                sample_weight, dtype=sample_score.dtype, device=device(sample_score)
            )
E           TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

sklearn/utils/_array_api.py:447: TypeError
_______________ test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-mps-float32] ________________

metric = <function accuracy_score at 0x130ba4040>, array_namespace = 'torch', device = 'mps', dtype_name = 'float32'
check_func = <function check_array_api_multiclass_classification_metric at 0x14d50c700>

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
    def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
>       check_func(metric, array_namespace, device, dtype_name)

sklearn/metrics/tests/test_common.py:1839: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sklearn/metrics/tests/test_common.py:1805: in check_array_api_multiclass_classification_metric
    check_array_api_metric(
sklearn/metrics/tests/test_common.py:1749: in check_array_api_metric
    metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
sklearn/utils/_param_validation.py:213: in wrapper
    return func(*args, **kwargs)
sklearn/metrics/_classification.py:221: in accuracy_score
    return _weighted_sum(score, sample_weight, normalize)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

sample_score = tensor([1., 1., 0., 0.], dtype=torch.float64), sample_weight = tensor([0.0000, 0.1000, 2.0000, 1.0000], device='mps:0')
normalize = True
xp = <module 'array_api_compat.torch' from '/Users/glemaitre/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/array_api_compat/torch/__init__.py'>

    def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
        # XXX: this function accepts Array API input but returns a Python scalar
        # float. The call to float() is convenient because it removes the need to
        # move back results from device to host memory (e.g. calling `.cpu()` on a
        # torch tensor). However, this might interact in unexpected ways (break?)
        # with lazy Array API implementations. See:
        # https://github.com/data-apis/array-api/issues/642
        if xp is None:
            xp, _ = get_namespace(sample_score)
        if normalize and _is_numpy_namespace(xp):
            sample_score_np = numpy.asarray(sample_score)
            if sample_weight is not None:
                sample_weight_np = numpy.asarray(sample_weight)
            else:
                sample_weight_np = None
            return float(numpy.average(sample_score_np, weights=sample_weight_np))
    
        if not xp.isdtype(sample_score.dtype, "real floating"):
            # We move to cpu device ahead of time since certain devices may not support
            # float64, but we want the same precision for all devices and namespaces.
            sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
    
        if sample_weight is not None:
>           sample_weight = xp.asarray(
                sample_weight, dtype=sample_score.dtype, device=device(sample_score)
            )
E           TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

sklearn/utils/_array_api.py:447: TypeError
__________________ test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-mps-float32] __________________

metric = <function zero_one_loss at 0x130ba4700>, array_namespace = 'torch', device = 'mps', dtype_name = 'float32'
check_func = <function check_array_api_binary_classification_metric at 0x14d50c670>

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
    def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
>       check_func(metric, array_namespace, device, dtype_name)

sklearn/metrics/tests/test_common.py:1839: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sklearn/metrics/tests/test_common.py:1776: in check_array_api_binary_classification_metric
    check_array_api_metric(
sklearn/metrics/tests/test_common.py:1749: in check_array_api_metric
    metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
sklearn/utils/_param_validation.py:213: in wrapper
    return func(*args, **kwargs)
sklearn/metrics/_classification.py:1069: in zero_one_loss
    score = accuracy_score(
sklearn/utils/_param_validation.py:186: in wrapper
    return func(*args, **kwargs)
sklearn/metrics/_classification.py:221: in accuracy_score
    return _weighted_sum(score, sample_weight, normalize)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

sample_score = tensor([1., 0., 0., 1.], dtype=torch.float64), sample_weight = tensor([0.0000, 0.1000, 2.0000, 1.0000], device='mps:0')
normalize = True
xp = <module 'array_api_compat.torch' from '/Users/glemaitre/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/array_api_compat/torch/__init__.py'>

    def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
        # XXX: this function accepts Array API input but returns a Python scalar
        # float. The call to float() is convenient because it removes the need to
        # move back results from device to host memory (e.g. calling `.cpu()` on a
        # torch tensor). However, this might interact in unexpected ways (break?)
        # with lazy Array API implementations. See:
        # https://github.com/data-apis/array-api/issues/642
        if xp is None:
            xp, _ = get_namespace(sample_score)
        if normalize and _is_numpy_namespace(xp):
            sample_score_np = numpy.asarray(sample_score)
            if sample_weight is not None:
                sample_weight_np = numpy.asarray(sample_weight)
            else:
                sample_weight_np = None
            return float(numpy.average(sample_score_np, weights=sample_weight_np))
    
        if not xp.isdtype(sample_score.dtype, "real floating"):
            # We move to cpu device ahead of time since certain devices may not support
            # float64, but we want the same precision for all devices and namespaces.
            sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
    
        if sample_weight is not None:
>           sample_weight = xp.asarray(
                sample_weight, dtype=sample_score.dtype, device=device(sample_score)
            )
E           TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

sklearn/utils/_array_api.py:447: TypeError
________________ test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-mps-float32] ________________

metric = <function zero_one_loss at 0x130ba4700>, array_namespace = 'torch', device = 'mps', dtype_name = 'float32'
check_func = <function check_array_api_multiclass_classification_metric at 0x14d50c700>

    @pytest.mark.parametrize(
        "array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
    )
    @pytest.mark.parametrize("metric, check_func", yield_metric_checker_combinations())
    def test_array_api_compliance(metric, array_namespace, device, dtype_name, check_func):
>       check_func(metric, array_namespace, device, dtype_name)

sklearn/metrics/tests/test_common.py:1839: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
sklearn/metrics/tests/test_common.py:1805: in check_array_api_multiclass_classification_metric
    check_array_api_metric(
sklearn/metrics/tests/test_common.py:1749: in check_array_api_metric
    metric_xp = metric(y_true_xp, y_pred_xp, sample_weight=sample_weight)
sklearn/utils/_param_validation.py:213: in wrapper
    return func(*args, **kwargs)
sklearn/metrics/_classification.py:1069: in zero_one_loss
    score = accuracy_score(
sklearn/utils/_param_validation.py:186: in wrapper
    return func(*args, **kwargs)
sklearn/metrics/_classification.py:221: in accuracy_score
    return _weighted_sum(score, sample_weight, normalize)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

sample_score = tensor([1., 1., 0., 0.], dtype=torch.float64), sample_weight = tensor([0.0000, 0.1000, 2.0000, 1.0000], device='mps:0')
normalize = True
xp = <module 'array_api_compat.torch' from '/Users/glemaitre/mambaforge/envs/sklearn_dev/lib/python3.10/site-packages/array_api_compat/torch/__init__.py'>

    def _weighted_sum(sample_score, sample_weight, normalize=False, xp=None):
        # XXX: this function accepts Array API input but returns a Python scalar
        # float. The call to float() is convenient because it removes the need to
        # move back results from device to host memory (e.g. calling `.cpu()` on a
        # torch tensor). However, this might interact in unexpected ways (break?)
        # with lazy Array API implementations. See:
        # https://github.com/data-apis/array-api/issues/642
        if xp is None:
            xp, _ = get_namespace(sample_score)
        if normalize and _is_numpy_namespace(xp):
            sample_score_np = numpy.asarray(sample_score)
            if sample_weight is not None:
                sample_weight_np = numpy.asarray(sample_weight)
            else:
                sample_weight_np = None
            return float(numpy.average(sample_score_np, weights=sample_weight_np))
    
        if not xp.isdtype(sample_score.dtype, "real floating"):
            # We move to cpu device ahead of time since certain devices may not support
            # float64, but we want the same precision for all devices and namespaces.
            sample_score = xp.astype(xp.asarray(sample_score, device="cpu"), xp.float64)
    
        if sample_weight is not None:
>           sample_weight = xp.asarray(
                sample_weight, dtype=sample_score.dtype, device=device(sample_score)
            )
E           TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

sklearn/utils/_array_api.py:447: TypeError
========================================================== short test summary info ==========================================================
FAILED sklearn/decomposition/tests/test_pca.py::test_pca_mle_array_api_compliance[PCA(n_components='mle',svd_solver='full')-check_array_api_input-torch-cpu-float32] - AssertionError
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_binary_classification_metric-torch-mps-float32] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[accuracy_score-check_array_api_multiclass_classification_metric-torch-mps-float32] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_binary_classification_metric-torch-mps-float32] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED sklearn/metrics/tests/test_common.py::test_array_api_compliance[zero_one_loss-check_array_api_multiclass_classification_metric-torch-mps-float32] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
================================ 5 failed, 209 passed, 153 skipped, 35329 deselected, 125 warnings in 30.99s ================================

I'm open to addressing those in a subsequent PR dedicated to solving the issues with the MPS device.

Member

@glemaitre glemaitre left a comment


The first failure (the AssertionError) also occurs without enabling MPS support, with just pytorch-cpu.

@fcharras
Contributor Author

The first one is surprising. I don't think it's related to this PR, which only does some renaming there. My guess is simply that the error tolerance threshold is too strict.

I don't think the second one is related either: if you run the test suite on main, I think you'll see the same failures. It does point to an issue in _sample_weight that I have seen before, which seems not to have been caught when that part was merged and which is fixed in #27904. Unfortunately, I don't have access to a float32-only device to confirm.
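
To illustrate the kind of ordering problem the traceback suggests, here is a minimal sketch. It assumes (this is a reading of the error message, not something confirmed in this thread) that the float64 conversion is attempted while the weights still live on the MPS device. `FakeTensor`, `astype`, `to_device` and the `UNSUPPORTED` table are hypothetical stand-ins written for this example; they are not torch or scikit-learn APIs, and this is not necessarily how #27904 addresses the issue.

```python
from dataclasses import dataclass


# Hypothetical stand-in for a device-backed array (not the torch API).
@dataclass(frozen=True)
class FakeTensor:
    values: tuple
    dtype: str
    device: str


# Assumption for illustration: "mps" cannot hold float64, "cpu" supports everything.
UNSUPPORTED = {("mps", "float64")}


def astype(t, dtype):
    """Cast on the tensor's current device; fail if that device lacks the dtype."""
    if (t.device, dtype) in UNSUPPORTED:
        raise TypeError(f"{t.device} does not support {dtype}")
    return FakeTensor(t.values, dtype, t.device)


def to_device(t, device):
    """Move the tensor without changing its dtype."""
    if (device, t.dtype) in UNSUPPORTED:
        raise TypeError(f"{device} does not support {t.dtype}")
    return FakeTensor(t.values, t.dtype, device)


def cast_then_move(t, dtype, device):
    # The cast happens on the source device, so it can fail there.
    return to_device(astype(t, dtype), device)


def move_then_cast(t, dtype, device):
    # The cast happens on the destination device, which supports the dtype.
    return astype(to_device(t, device), dtype)


weights = FakeTensor((0.5, 1.5), "float32", "mps")

try:
    cast_then_move(weights, "float64", "cpu")
    cast_first_ok = True
except TypeError:
    cast_first_ok = False

moved = move_then_cast(weights, "float64", "cpu")
print(cast_first_ok, moved.dtype, moved.device)
```

In this toy model, only the order of the two steps changes: moving the float32 weights to the CPU first and casting afterwards avoids asking the float32-only device for float64.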

@glemaitre
Member

OK, since this also happens on main, let's tackle the problem in dedicated PRs.

@glemaitre glemaitre merged commit 77aeb82 into scikit-learn:main Dec 18, 2023
@glemaitre
Member

Thanks @fcharras

@fcharras fcharras deleted the FIX/dtype_not_applied_in_array_api_metrics_test branch December 18, 2023 16:06
glemaitre pushed a commit to glemaitre/scikit-learn that referenced this pull request Feb 10, 2024
3 participants