diff --git a/doc/whats_new/upcoming_changes/sklearn.datasets/31685.fix.rst b/doc/whats_new/upcoming_changes/sklearn.datasets/31685.fix.rst new file mode 100644 index 0000000000000..5d954e538d707 --- /dev/null +++ b/doc/whats_new/upcoming_changes/sklearn.datasets/31685.fix.rst @@ -0,0 +1,5 @@ +- Fixed a regression that prevented extracting the downloaded dataset in + :func:`datasets.fetch_20newsgroups`, :func:`datasets.fetch_20newsgroups_vectorized`, + :func:`datasets.fetch_lfw_people` and :func:`datasets.fetch_lfw_pairs`. This + only affects Python versions `>=3.10.0,<=3.10.11` and `>=3.11.0,<=3.11.3`. + By :user:`Jérémie du Boisberranger <jeremiedbb>`. diff --git a/sklearn/datasets/_lfw.py b/sklearn/datasets/_lfw.py index 06420c41ed246..4f725b9250cc5 100644 --- a/sklearn/datasets/_lfw.py +++ b/sklearn/datasets/_lfw.py @@ -19,6 +19,7 @@ from ..utils import Bunch from ..utils._param_validation import Hidden, Interval, StrOptions, validate_params +from ..utils.fixes import tarfile_extractall from ._base import ( RemoteFileMetadata, _fetch_remote, @@ -117,10 +118,7 @@ def _check_fetch_lfw( logger.debug("Decompressing the data archive to %s", data_folder_path) with tarfile.open(archive_path, "r:gz") as fp: - # Use filter="data" to prevent the most dangerous security issues. - # For more details, see - # https://docs.python.org/3.9/library/tarfile.html#tarfile.TarFile.extractall - fp.extractall(path=lfw_home, filter="data") + tarfile_extractall(fp, path=lfw_home) remove(archive_path) diff --git a/sklearn/datasets/_twenty_newsgroups.py b/sklearn/datasets/_twenty_newsgroups.py index 62db8c5cbdc8e..1dc5fb6244f1b 100644 --- a/sklearn/datasets/_twenty_newsgroups.py +++ b/sklearn/datasets/_twenty_newsgroups.py @@ -43,6 +43,7 @@ from ..feature_extraction.text import CountVectorizer from ..utils import Bunch, check_random_state from ..utils._param_validation import Interval, StrOptions, validate_params +from ..utils.fixes import tarfile_extractall from . 
# TODO: Remove when Python min version >= 3.12.
def tarfile_extractall(tarfile, path):
    """Extract every member of an open tar archive into ``path``.

    The ``"data"`` extraction filter is applied whenever the running
    interpreter supports it, to prevent the most dangerous security issues.
    For more details, see
    https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall

    Some Python patch releases (``>=3.10.0,<=3.10.11`` and
    ``>=3.11.0,<=3.11.3``) do not accept the ``filter`` argument; on those
    the archive is extracted without a filter.

    Parameters
    ----------
    tarfile : tarfile.TarFile
        Archive opened for reading. Note that this parameter shadows the
        standard-library module of the same name inside this function.

    path : str or path-like
        Directory the members are extracted into.

    Returns
    -------
    None
    """
    # Local import under an alias because the ``tarfile`` parameter shadows
    # the module name.
    import tarfile as tarfile_module

    # Feature detection as recommended by the tarfile documentation. Catching
    # TypeError around the extraction instead would also swallow TypeErrors
    # raised *during* a filtered extraction and silently redo it unfiltered.
    if hasattr(tarfile_module, "data_filter"):
        tarfile.extractall(path, filter="data")
    else:
        tarfile.extractall(path)