Improve tests to make them run on variously typed data using the global_dtype fixture #22881

Open
jjerphan opened this issue Mar 17, 2022 · 8 comments
Labels
float32 Issues related to support for 32bit data Hard Hard level of difficulty help wanted Meta-issue General issue associated to an identified list of tasks module:test-suite everything related to our tests

Comments

@jjerphan
Member

jjerphan commented Mar 17, 2022

Context: the new global_dtype fixture and SKLEARN_RUN_FLOAT32_TESTS environment variable

The introduction of low-level computational routines for 32bit data motivated extending the tests so that they also run on 32bit data.

In this regard, #22690 introduced a new global_dtype fixture as well as the SKLEARN_RUN_FLOAT32_TESTS env. variable to make it possible to run the tests on 32bit data.

Running tests on 32bit data can be done by setting SKLEARN_RUN_FLOAT32_TESTS=1.

For instance, this runs the first global_dtype-parametrised test:

SKLEARN_RUN_FLOAT32_TESTS=1 pytest sklearn/feature_selection/tests/test_mutual_info.py -k test_compute_mi_cc

This allows running tests on 32bit datasets on selected CI jobs; currently a single CI job is used to run tests on 32bit data.

More details about the fixture can be found in the online dev doc for the SKLEARN_RUN_FLOAT32_TESTS env variable:

https://scikit-learn.org/dev/computing/parallelism.html#environment-variables
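
To make the mechanism concrete, here is a minimal sketch of how a global_dtype-style parametrisation could be wired to the environment variable in a conftest.py. It is an illustration under assumptions, not scikit-learn's actual conftest: the helper name dtypes_to_test is hypothetical, and the rule "float64 always, float32 only when the variable equals 1" is inferred from the command shown above.

```python
import os

import numpy as np


def dtypes_to_test(environ=os.environ):
    """Return the dtypes a global_dtype-style fixture would iterate over.

    Hypothetical helper: float64 is always tested, float32 only on request.
    """
    dtypes = [np.float64]
    if environ.get("SKLEARN_RUN_FLOAT32_TESTS", "0") == "1":
        dtypes.append(np.float32)
    return dtypes


def pytest_generate_tests(metafunc):
    # Standard pytest hook: parametrise every test that requests
    # an argument named `global_dtype`.
    if "global_dtype" in metafunc.fixturenames:
        metafunc.parametrize("global_dtype", dtypes_to_test())
```

With such a hook, a test that does not list global_dtype in its signature is left untouched, which is why only converted tests run twice.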

Guidelines to convert existing tests

  • Not all scikit-learn tests must use this fixture. Only tests that actually assert closeness of results using assert_allclose need to be parametrised. For instance, tests that check for the exception messages raised when passing invalid inputs must not be converted.

  • Tests using np.testing.assert_allclose must now use sklearn.utils._testing.assert_allclose as a drop-in replacement.

  • Check that the dtype of fitted attributes or return values that depend on the dtype of the input data structure actually have the expected dtype: typically when all inputs are continuous values in float32, it is often (but not always) the case that scikit-learn should carry all arithmetic operations at that precision level and return output arrays with the same precision level. There can be exceptions, in which case they could be made explicit with an inline comment in the test, possibly with a TODO marker when one thinks that the current behavior should change (see the related: Preserving dtype for float32 / float64 in transformers #11000 and Estimator check for dtype preservation for regressors #22682 for instance).

  • To avoid having to review huge PRs that impact many files at once and can lead to conflicts, let's open PRs that edit at most one test file at a time. For instance use a title such as:

TST use global_dtype in sklearn/_loss/tests/test_glm_distribution.py

To convert an existing test, the general pattern is to rewrite a function such as:

import numpy as np
from numpy.testing import assert_allclose

def test_some_function():
    # ...
    rng = np.random.RandomState(0)
    X = rng.rand(n_samples, n_features)
    y = rng.rand(n_samples)
    model.fit(X, y)
    # ...
    y_pred = model.predict(X)
    assert_allclose(y_pred, y_true)

to:

import numpy as np
from sklearn.utils._testing import assert_allclose

def test_some_function(global_dtype):
    # ...
    rng = np.random.RandomState(0)
    # Cast the inputs to the dtype under test (no copy if already float64).
    X = rng.rand(n_samples, n_features).astype(global_dtype, copy=False)
    y = rng.rand(n_samples).astype(global_dtype, copy=False)
    model.fit(X, y)
    # ...
    # Fitted attributes and predictions should keep the input dtype.
    assert model.fitted_param_.dtype == global_dtype
    y_pred = model.predict(X)
    assert y_pred.dtype == global_dtype
    assert_allclose(y_pred, y_true)

and then check that the test is passing on 32bit datasets:

SKLEARN_RUN_FLOAT32_TESTS=1 pytest sklearn/some_module/tests/test_some_module.py -k test_some_function

Failures are to be handled on a case-by-case basis.
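
One frequent kind of failure is a tolerance that is too strict for 32bit arithmetic, which is presumably why the guidelines ask for a dtype-aware assert_allclose. The sketch below illustrates the idea with a hypothetical helper, assert_allclose_dtype_aware; the default tolerances (1e-4 for float32, 1e-7 otherwise) are assumptions chosen for illustration and should be checked against sklearn.utils._testing.assert_allclose before relying on them.

```python
import numpy as np


def default_rtol(*arrays):
    """Pick a relative tolerance from the dtypes of the compared arrays.

    The values are assumptions for illustration: looser for float32.
    """
    rtols = [1e-4 if np.asanyarray(a).dtype == np.float32 else 1e-7 for a in arrays]
    return max(rtols)


def assert_allclose_dtype_aware(actual, desired, rtol=None, atol=0.0):
    """Hypothetical drop-in wrapper around np.testing.assert_allclose."""
    if rtol is None:
        rtol = default_rtol(actual, desired)
    np.testing.assert_allclose(actual, desired, rtol=rtol, atol=atol)


# A float32 result that is off by 8 float32 machine epsilons (about 9.5e-7).
eps32 = float(np.finfo(np.float32).eps)
actual = np.array([1.0 + 8 * eps32], dtype=np.float32)
desired = np.array([1.0], dtype=np.float64)

# Accepted with the dtype-aware default tolerance.
assert_allclose_dtype_aware(actual, desired)

# Rejected with a tolerance tuned for float64.
try:
    np.testing.assert_allclose(actual, desired, rtol=1e-7)
    strict_check_fails = False
except AssertionError:
    strict_check_fails = True
```

When a converted test still fails with the looser tolerance, the difference is more likely a real numerical issue than a rounding artefact and deserves a closer look.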

List of test modules to upgrade

find sklearn -name "test_*.py"
  • sklearn/_loss/tests/test_glm_distribution.py
  • sklearn/_loss/tests/test_link.py
  • sklearn/_loss/tests/test_loss.py
  • sklearn/cluster/tests/test_affinity_propagation.py TST use global_dtype in sklearn/cluster/tests/test_affinity_propagation.py #22667
  • sklearn/cluster/tests/test_bicluster.py
  • sklearn/cluster/tests/test_birch.py TST use global_dtype in sklearn/cluster/tests/test_birch.py #22671
  • sklearn/cluster/tests/test_dbscan.py
  • sklearn/cluster/tests/test_feature_agglomeration.py
  • sklearn/cluster/tests/test_hierarchical.py
  • sklearn/cluster/tests/test_k_means.py
  • sklearn/cluster/tests/test_mean_shift.py TST use global_dtype in sklearn/cluster/tests/test_mean_shift.py #22672
  • sklearn/cluster/tests/test_optics.py ENH Add dtype preservation to LocalOutlierFactor #22665
  • sklearn/cluster/tests/test_spectral.py ENH Add dtype preservation for SpectralClustering #22669
  • sklearn/compose/tests/test_column_transformer.py
  • sklearn/compose/tests/test_target.py
  • sklearn/covariance/tests/test_covariance.py
  • sklearn/covariance/tests/test_elliptic_envelope.py
  • sklearn/covariance/tests/test_graphical_lasso.py
  • sklearn/covariance/tests/test_robust_covariance.py
  • sklearn/cross_decomposition/tests/test_pls.py
  • sklearn/datasets/tests/test_20news.py
  • sklearn/datasets/tests/test_base.py
  • sklearn/datasets/tests/test_california_housing.py
  • sklearn/datasets/tests/test_common.py
  • sklearn/datasets/tests/test_covtype.py
  • sklearn/datasets/tests/test_kddcup99.py
  • sklearn/datasets/tests/test_lfw.py
  • sklearn/datasets/tests/test_olivetti_faces.py
  • sklearn/datasets/tests/test_openml.py
  • sklearn/datasets/tests/test_rcv1.py
  • sklearn/datasets/tests/test_samples_generator.py
  • sklearn/datasets/tests/test_svmlight_format.py
  • sklearn/decomposition/tests/test_dict_learning.py
  • sklearn/decomposition/tests/test_factor_analysis.py
  • sklearn/decomposition/tests/test_fastica.py
  • sklearn/decomposition/tests/test_incremental_pca.py
  • sklearn/decomposition/tests/test_kernel_pca.py
  • sklearn/decomposition/tests/test_nmf.py
  • sklearn/decomposition/tests/test_online_lda.py
  • sklearn/decomposition/tests/test_pca.py
  • sklearn/decomposition/tests/test_sparse_pca.py
  • sklearn/decomposition/tests/test_truncated_svd.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_binning.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_bitset.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_compare_lightgbm.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_grower.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_histogram.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_monotonic_contraints.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_predictor.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_splitting.py
  • sklearn/ensemble/_hist_gradient_boosting/tests/test_warm_start.py
  • sklearn/ensemble/tests/test_bagging.py
  • sklearn/ensemble/tests/test_base.py
  • sklearn/ensemble/tests/test_common.py
  • sklearn/ensemble/tests/test_forest.py
  • sklearn/ensemble/tests/test_gradient_boosting_loss_functions.py
  • sklearn/ensemble/tests/test_gradient_boosting.py
  • sklearn/ensemble/tests/test_iforest.py
  • sklearn/ensemble/tests/test_stacking.py
  • sklearn/ensemble/tests/test_voting.py
  • sklearn/ensemble/tests/test_weight_boosting.py
  • sklearn/experimental/tests/test_enable_hist_gradient_boosting.py
  • sklearn/experimental/tests/test_enable_iterative_imputer.py
  • sklearn/experimental/tests/test_enable_successive_halving.py
  • sklearn/feature_extraction/tests/test_dict_vectorizer.py
  • sklearn/feature_extraction/tests/test_feature_hasher.py
  • sklearn/feature_extraction/tests/test_image.py
  • sklearn/feature_extraction/tests/test_text.py
  • sklearn/feature_selection/tests/test_base.py
  • sklearn/feature_selection/tests/test_chi2.py
  • sklearn/feature_selection/tests/test_feature_select.py
  • sklearn/feature_selection/tests/test_from_model.py
  • sklearn/feature_selection/tests/test_mutual_info.py TST use global_dtype in sklearn/feature_selection/tests/test_mutual_info.py #22677
  • sklearn/feature_selection/tests/test_rfe.py
  • sklearn/feature_selection/tests/test_sequential.py
  • sklearn/feature_selection/tests/test_variance_threshold.py
  • sklearn/gaussian_process/tests/test_gpc.py
  • sklearn/gaussian_process/tests/test_gpr.py
  • sklearn/gaussian_process/tests/test_kernels.py
  • sklearn/impute/tests/test_base.py
  • sklearn/impute/tests/test_common.py
  • sklearn/impute/tests/test_impute.py
  • sklearn/impute/tests/test_knn.py
  • sklearn/inspection/_plot/tests/test_plot_partial_dependence.py
  • sklearn/inspection/tests/test_partial_dependence.py
  • sklearn/inspection/tests/test_permutation_importance.py
  • sklearn/linear_model/_glm/tests/test_glm.py
  • sklearn/linear_model/_glm/tests/test_link.py
  • sklearn/linear_model/tests/test_base.py
  • sklearn/linear_model/tests/test_bayes.py
  • sklearn/linear_model/tests/test_common.py
  • sklearn/linear_model/tests/test_coordinate_descent.py
  • sklearn/linear_model/tests/test_huber.py
  • sklearn/linear_model/tests/test_least_angle.py
  • sklearn/linear_model/tests/test_linear_loss.py
  • sklearn/linear_model/tests/test_logistic.py
  • sklearn/linear_model/tests/test_omp.py
  • sklearn/linear_model/tests/test_passive_aggressive.py
  • sklearn/linear_model/tests/test_perceptron.py
  • sklearn/linear_model/tests/test_quantile.py
  • sklearn/linear_model/tests/test_ransac.py
  • sklearn/linear_model/tests/test_ridge.py
  • sklearn/linear_model/tests/test_sag.py
  • sklearn/linear_model/tests/test_sgd.py
  • sklearn/linear_model/tests/test_sparse_coordinate_descent.py
  • sklearn/linear_model/tests/test_theil_sen.py
  • sklearn/manifold/tests/test_isomap.py TST use global_dtype in sklearn/manifold/tests/test_isomap.py #22673
  • sklearn/manifold/tests/test_locally_linear.py TST use global_dtype in sklearn/manifold/tests/test_locally_linear.py #22676
  • sklearn/manifold/tests/test_mds.py
  • sklearn/manifold/tests/test_spectral_embedding.py
  • sklearn/manifold/tests/test_t_sne.py TST use global_dtype in sklearn/manifold/tests/test_t_sne.py #22675
  • sklearn/metrics/_plot/tests/test_base.py
  • sklearn/metrics/_plot/tests/test_common_curve_display.py
  • sklearn/metrics/_plot/tests/test_confusion_matrix_display.py
  • sklearn/metrics/_plot/tests/test_det_curve_display.py
  • sklearn/metrics/_plot/tests/test_plot_confusion_matrix.py
  • sklearn/metrics/_plot/tests/test_plot_curve_common.py
  • sklearn/metrics/_plot/tests/test_plot_det_curve.py
  • sklearn/metrics/_plot/tests/test_plot_precision_recall.py
  • sklearn/metrics/_plot/tests/test_plot_roc_curve.py
  • sklearn/metrics/_plot/tests/test_precision_recall_display.py
  • sklearn/metrics/_plot/tests/test_roc_curve_display.py
  • sklearn/metrics/cluster/tests/test_bicluster.py
  • sklearn/metrics/cluster/tests/test_common.py
  • sklearn/metrics/cluster/tests/test_supervised.py
  • sklearn/metrics/cluster/tests/test_unsupervised.py
  • sklearn/metrics/tests/test_classification.py
  • sklearn/metrics/tests/test_common.py
  • sklearn/metrics/tests/test_dist_metrics.py
  • sklearn/metrics/tests/test_pairwise_distances_reduction.py
  • sklearn/metrics/tests/test_pairwise.py TST use global_dtype in sklearn/metrics/tests/test_pairwise.py #22666
  • sklearn/metrics/tests/test_ranking.py
  • sklearn/metrics/tests/test_regression.py
  • sklearn/metrics/tests/test_score_objects.py
  • sklearn/mixture/tests/test_bayesian_mixture.py
  • sklearn/mixture/tests/test_gaussian_mixture.py
  • sklearn/mixture/tests/test_mixture.py
  • sklearn/model_selection/tests/test_search.py
  • sklearn/model_selection/tests/test_split.py
  • sklearn/model_selection/tests/test_successive_halving.py
  • sklearn/model_selection/tests/test_validation.py
  • sklearn/neighbors/tests/test_ball_tree.py
  • sklearn/neighbors/tests/test_graph.py
  • sklearn/neighbors/tests/test_kd_tree.py
  • sklearn/neighbors/tests/test_kde.py
  • sklearn/neighbors/tests/test_lof.py ENH Add dtype preservation to LocalOutlierFactor #22665
  • sklearn/neighbors/tests/test_nca.py
  • sklearn/neighbors/tests/test_nearest_centroid.py
  • sklearn/neighbors/tests/test_neighbors_pipeline.py
  • sklearn/neighbors/tests/test_neighbors_tree.py
  • sklearn/neighbors/tests/test_neighbors.py TST use global_dtype in sklearn/neighbors/tests/test_neighbors.py #22663
  • sklearn/neighbors/tests/test_quad_tree.py
  • sklearn/neural_network/tests/test_base.py
  • sklearn/neural_network/tests/test_mlp.py
  • sklearn/neural_network/tests/test_rbm.py
  • sklearn/neural_network/tests/test_stochastic_optimizers.py
  • sklearn/preprocessing/tests/test_common.py
  • sklearn/preprocessing/tests/test_data.py
  • sklearn/preprocessing/tests/test_discretization.py
  • sklearn/preprocessing/tests/test_encoders.py
  • sklearn/preprocessing/tests/test_function_transformer.py
  • sklearn/preprocessing/tests/test_label.py
  • sklearn/preprocessing/tests/test_polynomial.py
  • sklearn/semi_supervised/tests/test_label_propagation.py
  • sklearn/semi_supervised/tests/test_self_training.py
  • sklearn/svm/tests/test_bounds.py
  • sklearn/svm/tests/test_sparse.py
  • sklearn/svm/tests/test_svm.py
  • sklearn/tests/test_base.py
  • sklearn/tests/test_build.py
  • sklearn/tests/test_calibration.py
  • sklearn/tests/test_check_build.py
  • sklearn/tests/test_common.py
  • sklearn/tests/test_config.py
  • sklearn/tests/test_discriminant_analysis.py
  • sklearn/tests/test_docstring_parameters.py
  • sklearn/tests/test_docstrings.py
  • sklearn/tests/test_dummy.py
  • sklearn/tests/test_init.py
  • sklearn/tests/test_isotonic.py
  • sklearn/tests/test_kernel_approximation.py
  • sklearn/tests/test_kernel_ridge.py
  • sklearn/tests/test_metaestimators.py
  • sklearn/tests/test_min_dependencies_readme.py
  • sklearn/tests/test_multiclass.py
  • sklearn/tests/test_multioutput.py
  • sklearn/tests/test_naive_bayes.py
  • sklearn/tests/test_pipeline.py
  • sklearn/tests/test_random_projection.py
  • sklearn/tree/tests/test_export.py
  • sklearn/tree/tests/test_reingold_tilford.py
  • sklearn/tree/tests/test_tree.py
  • sklearn/utils/tests/test_arpack.py
  • sklearn/utils/tests/test_arrayfuncs.py
  • sklearn/utils/tests/test_class_weight.py
  • sklearn/utils/tests/test_cython_blas.py
  • sklearn/utils/tests/test_cython_templating.py
  • sklearn/utils/tests/test_deprecation.py
  • sklearn/utils/tests/test_encode.py
  • sklearn/utils/tests/test_estimator_checks.py
  • sklearn/utils/tests/test_estimator_html_repr.py
  • sklearn/utils/tests/test_extmath.py
  • sklearn/utils/tests/test_fast_dict.py
  • sklearn/utils/tests/test_fixes.py
  • sklearn/utils/tests/test_graph.py
  • sklearn/utils/tests/test_metaestimators.py
  • sklearn/utils/tests/test_mocking.py
  • sklearn/utils/tests/test_multiclass.py
  • sklearn/utils/tests/test_murmurhash.py
  • sklearn/utils/tests/test_optimize.py
  • sklearn/utils/tests/test_parallel.py
  • sklearn/utils/tests/test_pprint.py
  • sklearn/utils/tests/test_random.py
  • sklearn/utils/tests/test_readonly_wrapper.py
  • sklearn/utils/tests/test_seq_dataset.py
  • sklearn/utils/tests/test_shortest_path.py
  • sklearn/utils/tests/test_show_versions.py
  • sklearn/utils/tests/test_sparsefuncs.py
  • sklearn/utils/tests/test_stats.py
  • sklearn/utils/tests/test_tags.py
  • sklearn/utils/tests/test_testing.py
  • sklearn/utils/tests/test_utils.py
  • sklearn/utils/tests/test_validation.py
  • sklearn/utils/tests/test_weight_vector.py

Note that some of those files might not have any test to update.
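
As a rough way to narrow the list down, the following sketch scans a source tree for test modules that call assert_allclose but never mention global_dtype. The function name candidate_test_files is hypothetical and the plain substring match is only a heuristic: it can miss relevant tests and flag irrelevant ones.

```python
from pathlib import Path


def candidate_test_files(root):
    """Yield test modules that use assert_allclose but not global_dtype."""
    for path in sorted(Path(root).rglob("test_*.py")):
        source = path.read_text(encoding="utf-8")
        if "assert_allclose" in source and "global_dtype" not in source:
            yield path
```

From the root of a scikit-learn checkout, iterating over candidate_test_files("sklearn") and printing each path would give a first list of candidates to review by hand.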

@github-actions github-actions bot added the Needs Triage Issue requires triage label Mar 17, 2022
@ogrisel
Member

ogrisel commented Mar 17, 2022

I have tried to expand the description of this issue, but GitHub fails to take my update into account for some reason.

Here is a copy of the edited issue description:

https://gist.github.com/ogrisel/d6867ca3c8c64e0b3d789f5a2f9ce067

@jjerphan
Member Author

Thanks, I just pasted your snippet in the description to replace the text.

@jjerphan
Member Author

jjerphan commented Mar 18, 2022

Regarding dtype preservation for fitted attributes, wouldn't it be better to introduce global checks for it with tags? I feel that we might spend a lot of time duplicating checks in many tests, potentially inserting a lot of TODOs.

To me it seems that this would better be addressed by another meta-issue similar to #11000.

What do you think?

cc @glemaitre

@jjerphan jjerphan added the float32 Issues related to support for 32bit data label Mar 18, 2022
@jeremiedbb jeremiedbb added Hard Hard level of difficulty and removed good first issue Easy with clear instructions to resolve labels Mar 18, 2022
@jeremiedbb
Member

I changed it to hard because it requires some knowledge to figure out which test should use the fixture and which test should not.

@jeremiedbb
Member

I think the list should be updated to remove all estimators that do not preserve the dtype yet.
If an estimator doesn't preserve the dtype, it means that it converts the input when the dtype is not float64 (usually), which means that we end up repeating the same test when we parametrize over the global_dtype.
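
A quick way to triage which computations are worth parametrising is to feed them float32 data and look at the output dtype. The sketch below uses a hypothetical helper, preserves_float32, and two toy functions standing in for estimator code; none of these names come from scikit-learn.

```python
import numpy as np


def preserves_float32(func, n_samples=20, n_features=3, seed=0):
    """Return True if `func` returns float32 output for float32 input."""
    rng = np.random.RandomState(seed)
    X = rng.rand(n_samples, n_features).astype(np.float32)
    return np.asarray(func(X)).dtype == np.float32


def centers_preserving(X):
    # Computes directly on the input, so the float32 dtype is kept.
    return X.mean(axis=0)


def centers_upcasting(X):
    # Converts to float64 first: the float32 run repeats the float64 one.
    X = np.asarray(X, dtype=np.float64)
    return X.mean(axis=0)


preserving_ok = preserves_float32(centers_preserving)
upcasting_ok = preserves_float32(centers_upcasting)
```

For the upcasting variant, parametrising over global_dtype mostly re-runs the float64 code path, which is the duplication described above.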

@jjerphan jjerphan removed the Needs Triage Issue requires triage label Mar 18, 2022
@adam2392
Member

@jjerphan just tried this for sklearn/tree/tests/test_tree.py and sklearn/ensemble/tests/test_forest.py and I don't think it works because the TreeBuilder has this conversion inside _check_input, which always forces the X input dtype

        elif X.dtype != DTYPE:  # where DTYPE is float32
            # since we have to copy we will make it fortran for efficiency
            X = np.asfortranarray(X, dtype=DTYPE)
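
The effect of that conversion can be reproduced in isolation. The function check_input below is a simplified, hypothetical stand-in for the quoted branch, not the actual tree code; it only shows that float64 input ends up as a Fortran-ordered float32 copy, while float32 input passes through unchanged.

```python
import numpy as np

DTYPE = np.float32  # dtype forced by the quoted snippet


def check_input(X):
    """Simplified stand-in for the conversion quoted above."""
    if X.dtype != DTYPE:
        # A copy is needed anyway, so make it Fortran-ordered.
        X = np.asfortranarray(X, dtype=DTYPE)
    return X


X64 = np.random.RandomState(0).rand(5, 2)
X32 = X64.astype(np.float32)

out64 = check_input(X64)  # converted copy
out32 = check_input(X32)  # returned as is
```

Whatever dtype the test passes in, the computation downstream of such a check runs on float32 data, so asserting that outputs follow global_dtype would not hold here.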

@ogrisel
Member

ogrisel commented Mar 22, 2022

I changed it to hard because it requires some knowledge to figure out which test should use the fixture and which test should not.

I extended the tests of #22806 using a combination of the global_dtype and global_random_seed fixtures and it revealed numerical problems that would not be visible with the default seed. I already fixed the most obvious ones but I am not 100% sure if this is good enough or not. So indeed I anticipate those PRs to be hard on average.
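
The combination of the two fixtures amounts to running each check over the cross product of dtypes and seeds. The sketch below emulates that without pytest on a toy computation (column means); the function name check_column_means, the seed range, and the tolerances are assumptions made for illustration.

```python
import itertools

import numpy as np


def check_column_means(global_dtype, global_random_seed):
    rng = np.random.RandomState(global_random_seed)
    X = rng.rand(200, 4).astype(global_dtype, copy=False)

    means = X.mean(axis=0)
    assert means.dtype == global_dtype

    # Reference computed at float64 precision on the same data.
    reference = X.astype(np.float64).mean(axis=0)
    rtol = 1e-4 if global_dtype == np.float32 else 1e-7
    np.testing.assert_allclose(means, reference, rtol=rtol)


n_checked = 0
for dtype, seed in itertools.product([np.float64, np.float32], range(5)):
    check_column_means(dtype, seed)
    n_checked += 1
```

A check that only passes for one particular seed at float32 precision would surface in such a loop, which matches the kind of problem reported above.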

@jjerphan
Member Author

@adam2392: some estimators aren't preserving the provided input dtype.

In that case, it still makes sense to extend the tests adding TODO, as explained in the description of this issue:

Check that the dtype of fitted attributes or return values that depend on the dtype of the input data structure actually have the expected dtype: typically when all inputs are continuous values in float32, it is often (but not always) the case that scikit-learn should carry all arithmetic operations at that precision level and return output arrays with the same precision level. There can be exceptions, in which case they could be made explicit with an inline comment in the test, possibly with a TODO marker when one thinks that the current behavior should change (see the related: #11000 and #22682 for instance).
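
For an estimator that does not preserve the dtype yet, the converted test can document the current behavior and flag it. The sketch below uses a hypothetical UpcastingScaler class and a check function called directly for both dtypes; in a real test module the function would be named test_... and receive global_dtype from the fixture.

```python
import numpy as np


class UpcastingScaler:
    """Hypothetical transformer that always computes in float64."""

    def fit(self, X):
        X = np.asarray(X, dtype=np.float64)
        self.scale_ = X.std(axis=0)
        return self

    def transform(self, X):
        return np.asarray(X, dtype=np.float64) / self.scale_


def check_upcasting_scaler(global_dtype):
    rng = np.random.RandomState(0)
    X = rng.rand(30, 3).astype(global_dtype, copy=False)

    scaler = UpcastingScaler().fit(X)
    X_scaled = scaler.transform(X)

    # TODO: the scaler upcasts to float64 for now. Once it preserves the
    # input dtype, compare against global_dtype instead.
    assert scaler.scale_.dtype == np.float64
    assert X_scaled.dtype == np.float64

    # After scaling, every column has unit standard deviation.
    np.testing.assert_allclose(X_scaled.std(axis=0), 1.0, rtol=1e-6)


for dtype in (np.float64, np.float32):
    check_upcasting_scaler(dtype)
```

The assertions pin down today's behavior, so a later change that makes the estimator dtype-preserving fails the test and points directly at the TODO to update.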

@cmarmo cmarmo added the module:test-suite everything related to our tests label Sep 16, 2022