From a4803c5132ce3591f5e9b4338279676d91b89101 Mon Sep 17 00:00:00 2001
From: Yao Xiao <108576690+Charlie-XIAO@users.noreply.github.com>
Date: Thu, 28 Sep 2023 19:19:36 +0800
Subject: [PATCH 01/12] FIX make dataset fetchers accept `os.PathLike` for
 `data_home` (#27468)

Co-authored-by: Guillaume Lemaitre
---
 sklearn/datasets/_base.py                  |  4 ++--
 sklearn/datasets/_california_housing.py    |  6 +++---
 sklearn/datasets/_covtype.py               |  4 ++--
 sklearn/datasets/_kddcup99.py              |  4 ++--
 sklearn/datasets/_lfw.py                   | 10 +++++-----
 sklearn/datasets/_olivetti_faces.py        |  6 +++---
 sklearn/datasets/_openml.py                |  6 +++---
 sklearn/datasets/_rcv1.py                  |  6 +++---
 sklearn/datasets/_species_distributions.py |  6 +++---
 sklearn/datasets/_twenty_newsgroups.py     |  8 ++++----
 sklearn/datasets/tests/test_base.py        | 18 +++++++++++++++++-
 11 files changed, 47 insertions(+), 31 deletions(-)

diff --git a/sklearn/datasets/_base.py b/sklearn/datasets/_base.py
index b2d198ecf8c2f..5675798137824 100644
--- a/sklearn/datasets/_base.py
+++ b/sklearn/datasets/_base.py
@@ -57,7 +57,7 @@ def get_data_home(data_home=None) -> str:
     ----------
     data_home : str or path-like, default=None
         The path to scikit-learn data directory. If `None`, the default path
-        is `~/sklearn_learn_data`.
+        is `~/scikit_learn_data`.
 
     Returns
     -------
@@ -84,7 +84,7 @@ def clear_data_home(data_home=None):
     ----------
     data_home : str or path-like, default=None
         The path to scikit-learn data directory. If `None`, the default path
-        is `~/sklearn_learn_data`.
+        is `~/scikit_learn_data`.
     """
     data_home = get_data_home(data_home)
     shutil.rmtree(data_home)
diff --git a/sklearn/datasets/_california_housing.py b/sklearn/datasets/_california_housing.py
index b48e7e10bdc4b..3153f0dd03f72 100644
--- a/sklearn/datasets/_california_housing.py
+++ b/sklearn/datasets/_california_housing.py
@@ -23,7 +23,7 @@
 import logging
 import tarfile
 
-from os import makedirs, remove
+from os import PathLike, makedirs, remove
 from os.path import exists
 
 import joblib
@@ -53,7 +53,7 @@
 
 @validate_params(
     {
-        "data_home": [str, None],
+        "data_home": [str, PathLike, None],
         "download_if_missing": ["boolean"],
         "return_X_y": ["boolean"],
         "as_frame": ["boolean"],
@@ -76,7 +76,7 @@ def fetch_california_housing(
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
diff --git a/sklearn/datasets/_covtype.py b/sklearn/datasets/_covtype.py
index 557899bc88e97..7620e08c5ec92 100644
--- a/sklearn/datasets/_covtype.py
+++ b/sklearn/datasets/_covtype.py
@@ -65,7 +65,7 @@
 
 @validate_params(
     {
-        "data_home": [str, None],
+        "data_home": [str, os.PathLike, None],
         "download_if_missing": ["boolean"],
         "random_state": ["random_state"],
         "shuffle": ["boolean"],
@@ -98,7 +98,7 @@ def fetch_covtype(
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
diff --git a/sklearn/datasets/_kddcup99.py b/sklearn/datasets/_kddcup99.py
index 17c49161c3bc2..444bd01737901 100644
--- a/sklearn/datasets/_kddcup99.py
+++ b/sklearn/datasets/_kddcup99.py
@@ -50,7 +50,7 @@
 @validate_params(
     {
         "subset": [StrOptions({"SA", "SF", "http", "smtp"}), None],
-        "data_home": [str, None],
+        "data_home": [str, os.PathLike, None],
         "shuffle": ["boolean"],
         "random_state": ["random_state"],
         "percent10": ["boolean"],
@@ -92,7 +92,7 @@ def fetch_kddcup99(
         To return the corresponding classical subsets of kddcup 99.
         If None, return the entire kddcup 99 dataset.
 
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py
index 345f56e89a03b..d06d29f21d0a5 100644
--- a/sklearn/datasets/_lfw.py
+++ b/sklearn/datasets/_lfw.py
@@ -10,7 +10,7 @@
 import logging
 from numbers import Integral, Real
 
-from os import listdir, makedirs, remove
+from os import PathLike, listdir, makedirs, remove
 from os.path import exists, isdir, join
 
 import numpy as np
@@ -234,7 +234,7 @@ def _fetch_lfw_people(
 
 @validate_params(
     {
-        "data_home": [str, None],
+        "data_home": [str, PathLike, None],
         "funneled": ["boolean"],
         "resize": [Interval(Real, 0, None, closed="neither"), None],
         "min_faces_per_person": [Interval(Integral, 0, None, closed="left"), None],
@@ -272,7 +272,7 @@ def fetch_lfw_people(
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
@@ -431,7 +431,7 @@ def _fetch_lfw_pairs(
 @validate_params(
     {
         "subset": [StrOptions({"train", "test", "10_folds"})],
-        "data_home": [str, None],
+        "data_home": [str, PathLike, None],
         "funneled": ["boolean"],
         "resize": [Interval(Real, 0, None, closed="neither"), None],
         "color": ["boolean"],
@@ -480,7 +480,7 @@ def fetch_lfw_pairs(
         official evaluation set that is meant to be used with a 10-folds
         cross validation.
 
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
diff --git a/sklearn/datasets/_olivetti_faces.py b/sklearn/datasets/_olivetti_faces.py
index 51710faccc417..8e1b3c91e254b 100644
--- a/sklearn/datasets/_olivetti_faces.py
+++ b/sklearn/datasets/_olivetti_faces.py
@@ -13,7 +13,7 @@
 # Copyright (c) 2011 David Warde-Farley
 # License: BSD 3 clause
 
-from os import makedirs, remove
+from os import PathLike, makedirs, remove
 from os.path import exists
 
 import joblib
@@ -36,7 +36,7 @@
 
 @validate_params(
     {
-        "data_home": [str, None],
+        "data_home": [str, PathLike, None],
         "shuffle": ["boolean"],
         "random_state": ["random_state"],
         "download_if_missing": ["boolean"],
@@ -67,7 +67,7 @@ def fetch_olivetti_faces(
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
diff --git a/sklearn/datasets/_openml.py b/sklearn/datasets/_openml.py
index 1c36dc8a25ce1..c9d09dc3ce46a 100644
--- a/sklearn/datasets/_openml.py
+++ b/sklearn/datasets/_openml.py
@@ -749,7 +749,7 @@ def _valid_data_column_names(features_list, target_columns):
         "name": [str, None],
         "version": [Interval(Integral, 1, None, closed="left"), StrOptions({"active"})],
         "data_id": [Interval(Integral, 1, None, closed="left"), None],
-        "data_home": [str, None],
+        "data_home": [str, os.PathLike, None],
         "target_column": [str, list, None],
         "cache": [bool],
         "return_X_y": [bool],
@@ -769,7 +769,7 @@ def fetch_openml(
     *,
     version: Union[str, int] = "active",
     data_id: Optional[int] = None,
-    data_home: Optional[str] = None,
+    data_home: Optional[Union[str, os.PathLike]] = None,
     target_column: Optional[Union[str, List]] = "default-target",
     cache: bool = True,
     return_X_y: bool = False,
@@ -815,7 +815,7 @@ def fetch_openml(
         dataset. If data_id is not given, name (and potential version) are
         used to obtain a dataset.
 
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the data sets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
diff --git a/sklearn/datasets/_rcv1.py b/sklearn/datasets/_rcv1.py
index a807d8e311466..d9f392d872216 100644
--- a/sklearn/datasets/_rcv1.py
+++ b/sklearn/datasets/_rcv1.py
@@ -10,7 +10,7 @@
 import logging
 from gzip import GzipFile
 
-from os import makedirs, remove
+from os import PathLike, makedirs, remove
 from os.path import exists, join
 
 import joblib
@@ -74,7 +74,7 @@
 
 @validate_params(
     {
-        "data_home": [str, None],
+        "data_home": [str, PathLike, None],
         "subset": [StrOptions({"train", "test", "all"})],
         "download_if_missing": ["boolean"],
         "random_state": ["random_state"],
@@ -111,7 +111,7 @@ def fetch_rcv1(
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
diff --git a/sklearn/datasets/_species_distributions.py b/sklearn/datasets/_species_distributions.py
index 0bfc4bb0fdaf5..a1e654d41e071 100644
--- a/sklearn/datasets/_species_distributions.py
+++ b/sklearn/datasets/_species_distributions.py
@@ -39,7 +39,7 @@
 import logging
 from io import BytesIO
 
-from os import makedirs, remove
+from os import PathLike, makedirs, remove
 from os.path import exists
 
 import joblib
@@ -136,7 +136,7 @@ def construct_grids(batch):
 
 @validate_params(
-    {"data_home": [str, None], "download_if_missing": ["boolean"]},
+    {"data_home": [str, PathLike, None], "download_if_missing": ["boolean"]},
     prefer_skip_nested_validation=True,
 )
 def fetch_species_distributions(*, data_home=None, download_if_missing=True):
@@ -146,7 +146,7 @@ def fetch_species_distributions(*, data_home=None, download_if_missing=True):
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify another download and cache folder for the datasets. By default
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
diff --git a/sklearn/datasets/_twenty_newsgroups.py b/sklearn/datasets/_twenty_newsgroups.py
index 637cf8e4fc8d4..5973e998c34b9 100644
--- a/sklearn/datasets/_twenty_newsgroups.py
+++ b/sklearn/datasets/_twenty_newsgroups.py
@@ -153,7 +153,7 @@ def strip_newsgroup_footer(text):
 
 @validate_params(
     {
-        "data_home": [str, None],
+        "data_home": [str, os.PathLike, None],
         "subset": [StrOptions({"train", "test", "all"})],
         "categories": ["array-like", None],
         "shuffle": ["boolean"],
@@ -191,7 +191,7 @@ def fetch_20newsgroups(
 
     Parameters
     ----------
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify a download and cache folder for the datasets. If None,
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
@@ -351,7 +351,7 @@
     {
         "subset": [StrOptions({"train", "test", "all"})],
         "remove": [tuple],
-        "data_home": [str, None],
+        "data_home": [str, os.PathLike, None],
         "download_if_missing": ["boolean"],
         "return_X_y": ["boolean"],
         "normalize": ["boolean"],
@@ -411,7 +411,7 @@ def fetch_20newsgroups_vectorized(
         ends of posts that look like signatures, and 'quotes' removes lines
         that appear to be quoting another post.
 
-    data_home : str, default=None
+    data_home : str or path-like, default=None
         Specify an download and cache folder for the datasets. If None,
         all scikit-learn data is stored in '~/scikit_learn_data' subfolders.
 
diff --git a/sklearn/datasets/tests/test_base.py b/sklearn/datasets/tests/test_base.py
index f31f20636c0c1..f84c275d67cf9 100644
--- a/sklearn/datasets/tests/test_base.py
+++ b/sklearn/datasets/tests/test_base.py
@@ -3,6 +3,7 @@
 import tempfile
 import warnings
 from functools import partial
+from pathlib import Path
 from pickle import dumps, loads
 
 import numpy as np
@@ -31,6 +32,16 @@
 from sklearn.utils.fixes import _is_resource
 
 
+class _DummyPath:
+    """Minimal class that implements the os.PathLike interface."""
+
+    def __init__(self, path):
+        self.path = path
+
+    def __fspath__(self):
+        return self.path
+
+
 def _remove_dir(path):
     if os.path.isdir(path):
         shutil.rmtree(path)
@@ -67,13 +78,18 @@ def test_category_dir_2(load_files_root):
     _remove_dir(test_category_dir2)
 
 
-def test_data_home(data_home):
+@pytest.mark.parametrize("path_container", [None, Path, _DummyPath])
+def test_data_home(path_container, data_home):
     # get_data_home will point to a pre-existing folder
+    if path_container is not None:
+        data_home = path_container(data_home)
     data_home = get_data_home(data_home=data_home)
     assert data_home == data_home
     assert os.path.exists(data_home)
 
     # clear_data_home will delete both the content and the folder it-self
+    if path_container is not None:
+        data_home = path_container(data_home)
     clear_data_home(data_home=data_home)
     assert not os.path.exists(data_home)
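
The net effect of patch 01 is that every `data_home` argument now accepts any
`os.PathLike` object, not only `str`. A minimal usage sketch, assuming network
access for the first call; the cache directory name is illustrative and not
part of the patch:

    from pathlib import Path

    from sklearn.datasets import fetch_california_housing, get_data_home

    # Any os.PathLike object now passes the [str, PathLike, None] constraint.
    cache_dir = Path.home() / "my_sklearn_cache"  # illustrative location
    housing = fetch_california_housing(data_home=cache_dir, download_if_missing=True)
    print(housing.data.shape)

    # get_data_home also accepts path-like input and still returns a plain str.
    print(get_data_home(data_home=cache_dir))
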
From c39c2bb697e887a0fe17df164789bd70c536f551 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Fri, 13 Oct 2023 10:47:11 +0200
Subject: [PATCH 02/12] FIX validate properly zero_division=np.nan when used
 in parallel processing (#27573)

---
 sklearn/metrics/_classification.py           | 18 ++++++++++++------
 sklearn/metrics/tests/test_classification.py | 27 ++++++++++++++++++++
 sklearn/utils/_param_validation.py           |  3 +++
 sklearn/utils/tests/test_param_validation.py |  3 +++
 4 files changed, 45 insertions(+), 6 deletions(-)

diff --git a/sklearn/metrics/_classification.py b/sklearn/metrics/_classification.py
index f916b7f86da38..0819abb463e2b 100644
--- a/sklearn/metrics/_classification.py
+++ b/sklearn/metrics/_classification.py
@@ -1079,7 +1079,8 @@ def zero_one_loss(y_true, y_pred, *, normalize=True, sample_weight=None):
         ],
         "sample_weight": ["array-like", None],
         "zero_division": [
-            Options(Real, {0.0, 1.0, np.nan}),
+            Options(Real, {0.0, 1.0}),
+            "nan",
             StrOptions({"warn"}),
         ],
     },
@@ -1260,7 +1261,8 @@ def f1_score(
         ],
         "sample_weight": ["array-like", None],
         "zero_division": [
-            Options(Real, {0.0, 1.0, np.nan}),
+            Options(Real, {0.0, 1.0}),
+            "nan",
             StrOptions({"warn"}),
         ],
     },
@@ -1542,7 +1544,8 @@ def _check_set_wise_labels(y_true, y_pred, average, labels, pos_label):
         "warn_for": [list, tuple, set],
         "sample_weight": ["array-like", None],
         "zero_division": [
-            Options(Real, {0.0, 1.0, np.nan}),
+            Options(Real, {0.0, 1.0}),
+            "nan",
             StrOptions({"warn"}),
         ],
     },
@@ -1979,7 +1982,8 @@ class after being classified as negative. This is the case when the
         ],
         "sample_weight": ["array-like", None],
         "zero_division": [
-            Options(Real, {0.0, 1.0, np.nan}),
+            Options(Real, {0.0, 1.0}),
+            "nan",
             StrOptions({"warn"}),
         ],
     },
@@ -2149,7 +2153,8 @@ def precision_score(
         ],
         "sample_weight": ["array-like", None],
         "zero_division": [
-            Options(Real, {0.0, 1.0, np.nan}),
+            Options(Real, {0.0, 1.0}),
+            "nan",
             StrOptions({"warn"}),
         ],
     },
@@ -2412,7 +2417,8 @@ def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=False):
         "digits": [Interval(Integral, 0, None, closed="left")],
         "output_dict": ["boolean"],
         "zero_division": [
-            Options(Real, {0.0, 1.0, np.nan}),
+            Options(Real, {0.0, 1.0}),
+            "nan",
             StrOptions({"warn"}),
         ],
     },
diff --git a/sklearn/metrics/tests/test_classification.py b/sklearn/metrics/tests/test_classification.py
index cfcb08a312443..afa3b90d5e8a9 100644
--- a/sklearn/metrics/tests/test_classification.py
+++ b/sklearn/metrics/tests/test_classification.py
@@ -27,6 +27,7 @@
     hinge_loss,
     jaccard_score,
     log_loss,
+    make_scorer,
     matthews_corrcoef,
     multilabel_confusion_matrix,
     precision_recall_fscore_support,
@@ -35,7 +36,9 @@
     zero_one_loss,
 )
 from sklearn.metrics._classification import _check_targets
+from sklearn.model_selection import cross_val_score
 from sklearn.preprocessing import LabelBinarizer, label_binarize
+from sklearn.tree import DecisionTreeClassifier
 from sklearn.utils._mocking import MockDataFrame
 from sklearn.utils._testing import (
     assert_allclose,
@@ -2802,3 +2805,27 @@ def test_classification_metric_pos_label_types(metric, classes):
     y_pred = y_true.copy()
     result = metric(y_true, y_pred, pos_label=pos_label)
     assert not np.any(np.isnan(result))
+
+
+@pytest.mark.parametrize(
+    "scoring",
+    [
+        make_scorer(f1_score, zero_division=np.nan),
+        make_scorer(fbeta_score, beta=2, zero_division=np.nan),
+        make_scorer(precision_score, zero_division=np.nan),
+        make_scorer(recall_score, zero_division=np.nan),
+    ],
+)
+def test_classification_metric_division_by_zero_nan_validaton(scoring):
+    """Check that we validate `np.nan` properly for classification metrics.
+
+    With `n_jobs=2` in cross-validation, the `np.nan` used for the singleton will be
+    different in the sub-process and we should not use the `is` operator but
+    `math.isnan`.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/27563
+    """
+    X, y = datasets.make_classification(random_state=0)
+    classifier = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X, y)
+    cross_val_score(classifier, X, y, scoring=scoring, n_jobs=2, error_score="raise")
diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py
index 0e30627ab06cc..bf063a1945621 100644
--- a/sklearn/utils/_param_validation.py
+++ b/sklearn/utils/_param_validation.py
@@ -46,6 +46,7 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name):
     - the string "boolean"
     - the string "verbose"
     - the string "cv_object"
+    - the string "nan"
     - a MissingValues object representing markers for missing values
     - a HasMethods object, representing method(s) an object must have
     - a Hidden object, representing a constraint not meant to be exposed to the user
@@ -137,6 +138,8 @@ def make_constraint(constraint):
         constraint = make_constraint(constraint.constraint)
         constraint.hidden = True
         return constraint
+    if isinstance(constraint, str) and constraint == "nan":
+        return _NanConstraint()
     raise ValueError(f"Unknown constraint type: {constraint}")
diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py
index ab6465538b863..2af84707cd2ed 100644
--- a/sklearn/utils/tests/test_param_validation.py
+++ b/sklearn/utils/tests/test_param_validation.py
@@ -23,6 +23,7 @@
     _CVObjects,
     _InstancesOf,
     _IterablesNotString,
+    _NanConstraint,
     _NoneConstraint,
     _PandasNAConstraint,
     _RandomStates,
@@ -387,6 +388,7 @@ def test_generate_valid_param(constraint):
     (Real, 0.5),
     ("boolean", False),
     ("verbose", 1),
+    ("nan", np.nan),
     (MissingValues(), -1),
     (MissingValues(), -1.0),
     (MissingValues(), None),
@@ -420,6 +422,7 @@ def test_is_satisfied_by(constraint_declaration, value):
     (MissingValues(numeric_only=True), MissingValues),
     (HasMethods("fit"), HasMethods),
     ("cv_object", _CVObjects),
+    ("nan", _NanConstraint),
 ],
 )
 def test_make_constraint(constraint_declaration, expected_constraint_class):
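
The test docstring above pins down the root cause: `np.nan` is only a
singleton within one process. A standalone sketch of the pitfall, using only
the standard library and NumPy (illustrative, not part of the patch):

    import math
    import pickle

    import numpy as np

    # Round-tripping np.nan through pickle, which is how a scorer reaches a
    # joblib worker, produces a new float object: identity is lost, value is not.
    roundtripped = pickle.loads(pickle.dumps(np.nan))
    assert roundtripped is not np.nan  # the `is` check the old constraint relied on
    assert math.isnan(roundtripped)    # the value-based check the fix switches to
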
From 425d3d22961c41aa7c8dcb37b33ba4f3106b6723 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 16 Oct 2023 15:26:36 +0200
Subject: [PATCH 03/12] FIX make sure that KernelPCA works with pandas output
 and arpack solver (#27583)

---
 sklearn/decomposition/_kernel_pca.py           |  2 +-
 sklearn/decomposition/tests/test_kernel_pca.py | 15 ++++++++++++++-
 2 files changed, 15 insertions(+), 2 deletions(-)

diff --git a/sklearn/decomposition/_kernel_pca.py b/sklearn/decomposition/_kernel_pca.py
index ccf79e896f210..800b472a9b3a6 100644
--- a/sklearn/decomposition/_kernel_pca.py
+++ b/sklearn/decomposition/_kernel_pca.py
@@ -432,7 +432,7 @@ def fit(self, X, y=None):
             raise ValueError("Cannot fit_inverse_transform with a precomputed kernel.")
         X = self._validate_data(X, accept_sparse="csr", copy=self.copy_X)
         self.gamma_ = 1 / X.shape[1] if self.gamma is None else self.gamma
-        self._centerer = KernelCenterer()
+        self._centerer = KernelCenterer().set_output(transform="default")
         K = self._get_kernel(X)
         self._fit_transform(K)
 
diff --git a/sklearn/decomposition/tests/test_kernel_pca.py b/sklearn/decomposition/tests/test_kernel_pca.py
index 3c95454749b4a..fdaa71314f43f 100644
--- a/sklearn/decomposition/tests/test_kernel_pca.py
+++ b/sklearn/decomposition/tests/test_kernel_pca.py
@@ -4,7 +4,8 @@
 import pytest
 import scipy.sparse as sp
 
-from sklearn.datasets import make_blobs, make_circles
+import sklearn
+from sklearn.datasets import load_iris, make_blobs, make_circles
 from sklearn.decomposition import PCA, KernelPCA
 from sklearn.exceptions import NotFittedError
 from sklearn.linear_model import Perceptron
@@ -550,3 +551,15 @@ def test_kernel_pca_inverse_correct_gamma():
     X2_recon = kpca2.inverse_transform(kpca1.transform(X))
 
     assert_allclose(X1_recon, X2_recon)
+
+
+def test_kernel_pca_pandas_output():
+    """Check that KernelPCA works with pandas output when the solver is arpack.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/27579
+    """
+    pytest.importorskip("pandas")
+    X, _ = load_iris(as_frame=True, return_X_y=True)
+    with sklearn.config_context(transform_output="pandas"):
+        KernelPCA(n_components=2, eigen_solver="arpack").fit_transform(X)
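
The one-line fix in patch 03 works because a per-estimator
`set_output(transform="default")` takes precedence over the global
`transform_output` configuration, so the internal centerer keeps returning a
plain ndarray (which has the single `dtype` attribute the arpack path reads).
A minimal sketch of that precedence, with `StandardScaler` standing in for the
internal `KernelCenterer` (assumes pandas is installed):

    import numpy as np

    import sklearn
    from sklearn.preprocessing import StandardScaler

    X = np.array([[1.0, 2.0], [3.0, 4.0]])

    with sklearn.config_context(transform_output="pandas"):
        # Under the global config, transformers return DataFrames...
        print(type(StandardScaler().fit_transform(X)).__name__)  # DataFrame

        # ...unless an instance is explicitly pinned back to ndarray output,
        # which is what the fix does for KernelPCA's internal centerer.
        pinned = StandardScaler().set_output(transform="default")
        print(type(pinned.fit_transform(X)).__name__)  # ndarray
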
From 1f0ba3242494e4cdbdf28910fc75d5b2f4f93fe2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Est=C3=A8ve?=
Date: Tue, 17 Oct 2023 14:25:17 +0200
Subject: [PATCH 04/12] FIX Make decision tree pickles deterministic (#27580)

Co-authored-by: Olivier Grisel
---
 doc/whats_new/v1.3.rst          | 17 +++++++++++++++++
 sklearn/tree/_tree.pyx          |  4 +++-
 sklearn/tree/tests/test_tree.py | 13 +++++++++++++
 3 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index 4999aba4c2e71..d6ea7352d4f3f 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -2,6 +2,23 @@
 
 .. currentmodule:: sklearn
 
+.. _changes_1_3_2:
+
+Version 1.3.2
+=============
+
+**October 2023**
+
+Changelog
+---------
+
+:mod:`sklearn.tree`
+...................
+
+- |Fix| Do not leak data via non-initialized memory in decision tree pickle files and make
+  the generation of those files deterministic. :pr:`27580` by :user:`Loïc Estève <lesteve>`.
+
+
 .. _changes_1_3_1:
 
 Version 1.3.1
diff --git a/sklearn/tree/_tree.pyx b/sklearn/tree/_tree.pyx
index 22d3d38ce981a..c843ad7500480 100644
--- a/sklearn/tree/_tree.pyx
+++ b/sklearn/tree/_tree.pyx
@@ -771,11 +771,13 @@ cdef class Tree:
         safe_realloc(&self.nodes, capacity)
         safe_realloc(&self.value, capacity * self.value_stride)
 
-        # value memory is initialised to 0 to enable classifier argmax
         if capacity > self.capacity:
+            # value memory is initialised to 0 to enable classifier argmax
             memset((self.value + self.capacity * self.value_stride), 0,
                    (capacity - self.capacity) * self.value_stride * sizeof(double))
 
+            # node memory is initialised to 0 to ensure deterministic pickle (padding in Node struct)
+            memset((self.nodes + self.capacity), 0, (capacity - self.capacity) * sizeof(Node))
 
         # if capacity smaller than node_count, adjust the counter
         if capacity < self.node_count:
diff --git a/sklearn/tree/tests/test_tree.py b/sklearn/tree/tests/test_tree.py
index 82c68b1327296..3dac1b38693c0 100644
--- a/sklearn/tree/tests/test_tree.py
+++ b/sklearn/tree/tests/test_tree.py
@@ -2626,3 +2626,16 @@ def test_sample_weight_non_uniform(make_data, Tree):
     tree_samples_removed.fit(X[1::2, :], y[1::2])
 
     assert_allclose(tree_samples_removed.predict(X), tree_with_sw.predict(X))
+
+
+def test_deterministic_pickle():
+    # Non-regression test for:
+    # https://github.com/scikit-learn/scikit-learn/issues/27268
+    # Uninitialised memory would lead to the two pickle strings being different.
+    tree1 = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
+    tree2 = DecisionTreeClassifier(random_state=0).fit(iris.data, iris.target)
+
+    pickle1 = pickle.dumps(tree1)
+    pickle2 = pickle.dumps(tree2)
+
+    assert pickle1 == pickle2

From fefeba49cdb6a5f5397dde2680e31eb2561cd5ea Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 17 Oct 2023 14:54:36 +0200
Subject: [PATCH 05/12] REL bump to 1.3.2

---
 sklearn/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sklearn/__init__.py b/sklearn/__init__.py
index 8df2c91fb2ca9..48d907fa5ad23 100644
--- a/sklearn/__init__.py
+++ b/sklearn/__init__.py
@@ -38,7 +38,7 @@
 # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
 # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
 #
-__version__ = "1.3.1"
+__version__ = "1.3.2"
 
 
 # On OSX, we can get a runtime error due to multiple OpenMP libraries loaded

From dcebb966e1bfbff6dce0824af0f173b65e31a87a Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Tue, 17 Oct 2023 14:55:19 +0200
Subject: [PATCH 06/12] [cd build][azure parallel] trigger ci/cd

From 7cd53e1ddf69a9d500b9fef3308e794acf35e1fe Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 18 Oct 2023 11:10:51 +0200
Subject: [PATCH 07/12] DOC add version 1.3.2 into landing page (#27604)

---
 doc/templates/index.html | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/doc/templates/index.html b/doc/templates/index.html
index 40f58f7bbefe3..b41ede9382a49 100644
--- a/doc/templates/index.html
+++ b/doc/templates/index.html
@@ -169,7 +169,9 @@

           News
 
           • On-going development: What's new (Changelog)
-          • September 2023. scikit-learn 1.3.1 is available for download (Changelog).
+          • October 2023. scikit-learn 1.3.2 is available for download (Changelog).
+
+          • September 2023. scikit-learn 1.3.1 is available for download (Changelog).
           • June 2023. scikit-learn 1.3.0 is available for download (Changelog).
From cb7816d0147c0e290bb47fbe6d092f18a48b80b3 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 18 Oct 2023 16:05:39 +0200
Subject: [PATCH 08/12] MAINT remove prerelease flag for Python 3.12 (#27605)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: Loïc Estève
---
 .github/workflows/wheels.yml     | 6 ------
 build_tools/cirrus/arm_wheel.yml | 4 ----
 2 files changed, 10 deletions(-)

diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml
index ae8f182450202..b82a114bff1af 100644
--- a/.github/workflows/wheels.yml
+++ b/.github/workflows/wheels.yml
@@ -70,8 +70,6 @@ jobs:
           - os: windows-latest
             python: 312
            platform_id: win_amd64
-            # TODO: remove when Python 3.12 is released
-            prerelease: "True"
 
          # Linux 64 bit manylinux2014
          - os: ubuntu-latest
            python: 38
            platform_id: manylinux_x86_64
            manylinux_image: manylinux2014
@@ -97,8 +95,6 @@ jobs:
            python: 312
            platform_id: manylinux_x86_64
            manylinux_image: manylinux2014
-            # TODO: remove when Python 3.12 is released
-            prerelease: "True"
 
          # MacOS x86_64
          - os: macos-latest
            python: 38
            platform_id: macosx_x86_64
@@ -116,8 +112,6 @@ jobs:
          - os: macos-latest
            python: 312
            platform_id: macosx_x86_64
-            # TODO: remove when Python 3.12 is released
-            prerelease: "True"
 
          # MacOS arm64
          # The wheel for the latest Python version is built and tested on
diff --git a/build_tools/cirrus/arm_wheel.yml b/build_tools/cirrus/arm_wheel.yml
index c5f5a34afa490..f210eea817601 100644
--- a/build_tools/cirrus/arm_wheel.yml
+++ b/build_tools/cirrus/arm_wheel.yml
@@ -25,8 +25,6 @@ macos_arm64_wheel_task:
     # is actually tested on Cirrus CI.
     - env:
         CIBW_BUILD: cp312-macosx_arm64
-        # TODO: remove when Python 3.12 is released
-        CIBW_PRERELEASE_PYTHONS: True
 
   conda_script:
     - curl -L --retry 10 -o ~/mambaforge.sh https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
@@ -78,8 +76,6 @@ linux_arm64_wheel_task:
         CIBW_TEST_SKIP: "*_aarch64"
     - env:
         CIBW_BUILD: cp312-manylinux_aarch64
-        # TODO: remove when Python 3.12 is released
-        CIBW_PRERELEASE_PYTHONS: True
 
   cibuildwheel_script:
     - apt install -y python3 python-is-python3

From cd1938ab5fd04eee267da4d0007de2980e5fdefc Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 18 Oct 2023 16:52:28 +0200
Subject: [PATCH 09/12] DOC move some fixes from 1.4 to 1.3.2 (#27602)

Co-authored-by: Olivier Grisel
---
 doc/whats_new/v1.3.rst | 24 ++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/doc/whats_new/v1.3.rst b/doc/whats_new/v1.3.rst
index d6ea7352d4f3f..70d14ab285eea 100644
--- a/doc/whats_new/v1.3.rst
+++ b/doc/whats_new/v1.3.rst
@@ -12,6 +12,30 @@ Version 1.3.2
 
 Changelog
 ---------
 
+:mod:`sklearn.datasets`
+.......................
+
+- |Fix| All dataset fetchers now accept `data_home` as any object that implements
+  the :class:`os.PathLike` interface, for instance, :class:`pathlib.Path`.
+  :pr:`27468` by :user:`Yao Xiao <Charlie-XIAO>`.
+
+:mod:`sklearn.decomposition`
+............................
+
+- |Fix| Fixes a bug in :class:`decomposition.KernelPCA` by forcing the output of
+  the internal :class:`preprocessing.KernelCenterer` to be a default array. When the
+  arpack solver is used, it expects an array with a `dtype` attribute.
+  :pr:`27583` by :user:`Guillaume Lemaitre <glemaitre>`.
+
+:mod:`sklearn.metrics`
+......................
+
+- |Fix| Fixes a bug for metrics using `zero_division=np.nan`
+  (e.g. :func:`~metrics.precision_score`) within a parallel loop
+  (e.g. :func:`~model_selection.cross_val_score`) where the singleton for `np.nan`
+  will be different in the sub-processes.
+  :pr:`27573` by :user:`Guillaume Lemaitre <glemaitre>`.
+
 :mod:`sklearn.tree`
 ...................
 

From addca4bc99b5c5c9495e6fec22060406bde64dfe Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Wed, 18 Oct 2023 16:56:44 +0200
Subject: [PATCH 10/12] [cd build][azure parallel] trigger ci/cd builds

From be2e067fa50c9b3f883799e87bd4420487e8e0fc Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Thu, 19 Oct 2023 12:35:45 +0200
Subject: [PATCH 11/12] TST change random seed to make graphical lasso test
 pass (#27616)

---
 sklearn/covariance/tests/test_graphical_lasso.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)

diff --git a/sklearn/covariance/tests/test_graphical_lasso.py b/sklearn/covariance/tests/test_graphical_lasso.py
index 317bf2aa85124..aaee1919e8dcc 100644
--- a/sklearn/covariance/tests/test_graphical_lasso.py
+++ b/sklearn/covariance/tests/test_graphical_lasso.py
@@ -24,7 +24,12 @@
 )
 
 
-def test_graphical_lasso(random_state=0):
+def test_graphical_lassos(random_state=1):
+    """Test the graphical lasso solvers.
+
+    This check is unstable for some random seeds, where the covariances found with the
+    "cd" and "lars" solvers differ (4 cases out of 100 tries).
+    """
     # Sample data from a sparse multivariate normal
     dim = 20
     n_samples = 100
@@ -46,10 +51,11 @@
         costs, dual_gap = np.array(costs).T
         # Check that the costs always decrease (doesn't hold if alpha == 0)
         if not alpha == 0:
-            assert_array_less(np.diff(costs), 0)
+            # use 1e-12 since the cost can be exactly 0
+            assert_array_less(np.diff(costs), 1e-12)
 
     # Check that the 2 approaches give similar results
-    assert_array_almost_equal(covs["cd"], covs["lars"], decimal=4)
-    assert_array_almost_equal(icovs["cd"], icovs["lars"], decimal=4)
+    assert_allclose(covs["cd"], covs["lars"], atol=1e-4)
+    assert_allclose(icovs["cd"], icovs["lars"], atol=1e-4)
 
     # Smoke test the estimator
     model = GraphicalLasso(alpha=0.25).fit(X)

From fb2e5347cba15e48a1f69e052e0d3a1cec4b5727 Mon Sep 17 00:00:00 2001
From: Guillaume Lemaitre
Date: Mon, 23 Oct 2023 12:08:47 +0200
Subject: [PATCH 12/12] [cd build][azure parallel] trigger CI/CD
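
The solver-agreement property that `test_graphical_lassos` (patch 11) exercises
can be reproduced standalone. A short sketch comparing the "cd" and "lars"
modes of `graphical_lasso` directly; the alpha value and data here are
illustrative, and, per the docstring above, a few random seeds are known to
exceed the tolerance:

    import numpy as np

    from sklearn.covariance import empirical_covariance, graphical_lasso

    rng = np.random.RandomState(1)  # seed 1 is what the test settled on
    X = rng.randn(100, 10)
    emp_cov = empirical_covariance(X)

    # Both solvers should estimate (nearly) the same regularized covariance.
    cov_cd, _ = graphical_lasso(emp_cov, alpha=0.2, mode="cd")
    cov_lars, _ = graphical_lasso(emp_cov, alpha=0.2, mode="lars")
    print(np.abs(cov_cd - cov_lars).max())  # small, typically below 1e-4
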