From 23403291e53533def8b82b48039dbb091f4c2742 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 27 Dec 2022 17:40:48 +0100 Subject: [PATCH 01/15] FIX get config from dispatcher thread in delayed by default --- sklearn/_config.py | 27 +++++++++++++++++++++------ sklearn/utils/fixes.py | 18 +++++++++++------- 2 files changed, 32 insertions(+), 13 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index e4c398c9c5444..b59922d366acc 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -4,7 +4,7 @@ from contextlib import contextmanager as contextmanager import threading -_global_config = { +_global_config_default = { "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)), "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)), "print_changed_only": True, @@ -17,19 +17,31 @@ "transform_output": "default", } _threadlocal = threading.local() +_thread_config = {} def _get_threadlocal_config(): - """Get a threadlocal **mutable** configuration. If the configuration - does not exist, copy the default global configuration.""" + """Get a threadlocal **mutable** configuration. + + If the configuration does not exist, copy the default global configuration. + The configuration is also registered to in a global dictionary where the + key is the thread id. + """ if not hasattr(_threadlocal, "global_config"): - _threadlocal.global_config = _global_config.copy() + _threadlocal.global_config = _global_config_default.copy() + _thread_config[threading.get_ident()] = _threadlocal.global_config return _threadlocal.global_config -def get_config(): +def get_config(thread_id=None): """Retrieve current values for configuration set by :func:`set_config`. + Parameters + ---------- + thread_id : int, default=None + The thread id from which to retrieve the configuration. If `None`, + the current thread id is used. + Returns ------- config : dict @@ -42,7 +54,10 @@ def get_config(): """ # Return a copy of the threadlocal configuration so that users will # not be able to modify the configuration with the returned dict. - return _get_threadlocal_config().copy() + if thread_id is None: + return _get_threadlocal_config().copy() + _get_threadlocal_config() + return _thread_config[thread_id].copy() def set_config( diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index b43afa950a163..811f4c30da72c 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -14,6 +14,7 @@ from importlib import resources import functools import sys +import threading import sklearn import numpy as np @@ -107,22 +108,25 @@ def _eigh(*args, **kwargs): # remove when https://github.com/joblib/joblib/issues/1071 is fixed -def delayed(function): +def delayed(func, thread_id=threading.get_ident()): """Decorator used to capture the arguments of a function.""" - @functools.wraps(function) - def delayed_function(*args, **kwargs): - return _FuncWrapper(function), args, kwargs + def decorate(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + return _FuncWrapper(func, thread_id=thread_id), args, kwargs - return delayed_function + return wrapper + + return decorate(func) class _FuncWrapper: """ "Load the global configuration before calling the function.""" - def __init__(self, function): + def __init__(self, function, thread_id=None): self.function = function - self.config = get_config() + self.config = get_config(thread_id=thread_id) update_wrapper(self, self.function) def __call__(self, *args, **kwargs): From c7efddaef9f2a3b65d31e638515c81cad4909336 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 27 Dec 2022 17:57:27 +0100 Subject: [PATCH 02/15] use weakref --- sklearn/_config.py | 13 +++++++------ sklearn/utils/fixes.py | 8 ++++---- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index b59922d366acc..07b471b16b9b8 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -3,6 +3,7 @@ import os from contextlib import contextmanager as contextmanager import threading +import weakref _global_config_default = { "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)), @@ -17,7 +18,7 @@ "transform_output": "default", } _threadlocal = threading.local() -_thread_config = {} +_thread_config = weakref.WeakKeyDictionary() # type: ignore def _get_threadlocal_config(): @@ -29,11 +30,11 @@ def _get_threadlocal_config(): """ if not hasattr(_threadlocal, "global_config"): _threadlocal.global_config = _global_config_default.copy() - _thread_config[threading.get_ident()] = _threadlocal.global_config + _thread_config[threading.current_thread()] = _threadlocal.global_config return _threadlocal.global_config -def get_config(thread_id=None): +def get_config(thread=None): """Retrieve current values for configuration set by :func:`set_config`. Parameters @@ -54,10 +55,10 @@ def get_config(thread_id=None): """ # Return a copy of the threadlocal configuration so that users will # not be able to modify the configuration with the returned dict. - if thread_id is None: + if thread is None: return _get_threadlocal_config().copy() - _get_threadlocal_config() - return _thread_config[thread_id].copy() + _get_threadlocal_config() # register the config to the thread if does not exist + return _thread_config[thread].copy() def set_config( diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 811f4c30da72c..f7cd5189262fb 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -108,13 +108,13 @@ def _eigh(*args, **kwargs): # remove when https://github.com/joblib/joblib/issues/1071 is fixed -def delayed(func, thread_id=threading.get_ident()): +def delayed(func, thread=threading.current_thread()): """Decorator used to capture the arguments of a function.""" def decorate(func): @functools.wraps(func) def wrapper(*args, **kwargs): - return _FuncWrapper(func, thread_id=thread_id), args, kwargs + return _FuncWrapper(func, thread=thread), args, kwargs return wrapper @@ -124,9 +124,9 @@ def wrapper(*args, **kwargs): class _FuncWrapper: """ "Load the global configuration before calling the function.""" - def __init__(self, function, thread_id=None): + def __init__(self, function, thread): self.function = function - self.config = get_config(thread_id=thread_id) + self.config = get_config(thread=thread) update_wrapper(self, self.function) def __call__(self, *args, **kwargs): From 2d4e553e1d0054672fe927d6d8cb28bb8f95da7b Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 27 Dec 2022 18:00:43 +0100 Subject: [PATCH 03/15] doc stuff --- sklearn/_config.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 07b471b16b9b8..0222b21d7f8ca 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -25,8 +25,8 @@ def _get_threadlocal_config(): """Get a threadlocal **mutable** configuration. If the configuration does not exist, copy the default global configuration. - The configuration is also registered to in a global dictionary where the - key is the thread id. + The configuration is also registered to a global dictionary where the keys + are weak references to the thread objects. """ if not hasattr(_threadlocal, "global_config"): _threadlocal.global_config = _global_config_default.copy() @@ -39,9 +39,9 @@ def get_config(thread=None): Parameters ---------- - thread_id : int, default=None - The thread id from which to retrieve the configuration. If `None`, - the current thread id is used. + thread : Thread, default=None + The thread for which to retrieve the configuration. If None, the + configuration of the current thread is returned. Returns ------- @@ -55,9 +55,9 @@ def get_config(thread=None): """ # Return a copy of the threadlocal configuration so that users will # not be able to modify the configuration with the returned dict. + threadlocal_config = _get_threadlocal_config() if thread is None: - return _get_threadlocal_config().copy() - _get_threadlocal_config() # register the config to the thread if does not exist + return threadlocal_config.copy() return _thread_config[thread].copy() From 3c03261112973ca056e1251e80f135a7b5dc2566 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 27 Dec 2022 18:18:27 +0100 Subject: [PATCH 04/15] TST add test for delayed --- sklearn/utils/tests/test_fixes.py | 41 +++++++++++++++++++++++++++++-- 1 file changed, 39 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index 3566897da5efc..c4c579257086b 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -4,15 +4,18 @@ # License: BSD 3 clause import math +import threading import numpy as np import pytest import scipy.stats +from joblib import Parallel + +import sklearn from sklearn.utils._testing import assert_array_equal -from sklearn.utils.fixes import _object_dtype_isnan -from sklearn.utils.fixes import loguniform +from sklearn.utils.fixes import _object_dtype_isnan, delayed, loguniform @pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1])) @@ -46,3 +49,37 @@ def test_loguniform(low, high, base): assert loguniform(base**low, base**high).rvs(random_state=0) == loguniform( base**low, base**high ).rvs(random_state=0) + + +def test_delayed_fetching_right_config(): + """Check that `delayed` function fetches the right config associated to + the main thread. + + Non-regression test for: + https://github.com/scikit-learn/scikit-learn/issues/25239 + """ + + def get_working_memory(): + return sklearn.get_config()["working_memory"] + + n_iter = 10 + + # by default, we register the main thread and we should retrieve the + # parameters defined within the context manager + with sklearn.config_context(working_memory=123): + results = Parallel(n_jobs=2, pre_dispatch=4)( + delayed(get_working_memory)() for _ in range(n_iter) + ) + + assert results == [123] * n_iter + + # simulate that we refer to another thread + local_thread = threading.Thread(target=sklearn.get_config) + local_thread.start() + local_thread.join() + with sklearn.config_context(working_memory=123): + results = Parallel(n_jobs=2, pre_dispatch=4)( + delayed(get_working_memory, thread=local_thread)() for _ in range(n_iter) + ) + + assert results == [get_working_memory()] * n_iter From cd38957bd198b92286f79baa8384b9f47c648697 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 28 Dec 2022 10:31:51 +0100 Subject: [PATCH 05/15] TST add more test in config --- sklearn/tests/test_config.py | 43 ++++++++++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index a0b8f29662b69..e84ffbec54d8e 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -1,4 +1,5 @@ import time +import threading from concurrent.futures import ThreadPoolExecutor from joblib import Parallel @@ -120,15 +121,15 @@ def test_config_threadsafe_joblib(backend): should be the same as the value passed to the function. In other words, it is not influenced by the other job setting assume_finite to True. """ - assume_finites = [False, True] - sleep_durations = [0.1, 0.2] + assume_finites = [False, True, False, True] + sleep_durations = [0.1, 0.2, 0.1, 0.2] - items = Parallel(backend=backend, n_jobs=2)( + items = Parallel(backend=backend, n_jobs=2, pre_dispatch=2)( delayed(set_assume_finite)(assume_finite, sleep_dur) for assume_finite, sleep_dur in zip(assume_finites, sleep_durations) ) - assert items == [False, True] + assert items == [False, True, False, True] def test_config_threadsafe(): @@ -136,8 +137,8 @@ def test_config_threadsafe(): between threads. Same test as `test_config_threadsafe_joblib` but with `ThreadPoolExecutor`.""" - assume_finites = [False, True] - sleep_durations = [0.1, 0.2] + assume_finites = [False, True, False, True] + sleep_durations = [0.1, 0.2, 0.1, 0.2] with ThreadPoolExecutor(max_workers=2) as e: items = [ @@ -145,4 +146,32 @@ def test_config_threadsafe(): for output in e.map(set_assume_finite, assume_finites, sleep_durations) ] - assert items == [False, True] + assert items == [False, True, False, True] + + +def test_get_config_thread_dependent(): + """Check that we can retrieve the config file from a specific thread.""" + + def set_definitive_assume_finite(assume_finite, sleep_duration): + set_config(assume_finite=assume_finite) + time.sleep(sleep_duration) + return get_config()["assume_finite"] + + thread = threading.Thread(target=set_definitive_assume_finite, args=(True, 0.1)) + thread.start() + thread.join() + + thread_specific_config = get_config(thread=thread) + assert thread_specific_config["assume_finite"] is True + main_thread_config = get_config() + assert main_thread_config["assume_finite"] is False + + # check that we have 2 threads registered in the thread config dictionary + from sklearn._config import _thread_config + + assert len(_thread_config) == 2 + + # delete the thread and check that the dictionary does keep a reference to it + # since we use a weakref dictionary + del thread + assert len(_thread_config) == 1 From 37ce3efb456e082b165b54e8854677612e02a382 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 28 Dec 2022 10:39:39 +0100 Subject: [PATCH 06/15] DOC add a changelog entry --- doc/whats_new/v1.2.rst | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index 6442979db402e..fdaa22798ae9c 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -9,9 +9,24 @@ Version 1.2.1 **In Development** +Changes impacting all modules +----------------------------- + +- |Fix| Fix a bug that was ignoring the global configuration in estimators using + `n_jobs > 1`. This bug was triggered for the tasks not dispatch by the main + thread in `joblib` since :func:`sklearn.get_config` uses thread local configuration. + :pr:`25242` by :user:`Guillaume Lemaitre `. + Changelog --------- +:mod:`sklearn` +.............. + +- |Enhancement| :func:`sklearn.get_config` takes a parameter `thread` allowing to + retrieve the local configuration of this specific `thread`. + :pr:`25242` by :user:`Guillaume Lemaitre `. + :mod:`sklearn.base` ................... From 011a7ecd935b7deb7dd534204a268c11fdd8933a Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 28 Dec 2022 10:52:06 +0100 Subject: [PATCH 07/15] be nice with mypy --- sklearn/_config.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 0222b21d7f8ca..1e0773501a5e5 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -1,9 +1,11 @@ """Global configuration state and functions for management """ import os -from contextlib import contextmanager as contextmanager import threading -import weakref + +from contextlib import contextmanager as contextmanager +from typing import Dict +from weakref import WeakKeyDictionary _global_config_default = { "assume_finite": bool(os.environ.get("SKLEARN_ASSUME_FINITE", False)), @@ -18,7 +20,7 @@ "transform_output": "default", } _threadlocal = threading.local() -_thread_config = weakref.WeakKeyDictionary() # type: ignore +_thread_config = WeakKeyDictionary() # type: WeakKeyDictionary[threading.Thread, Dict] def _get_threadlocal_config(): From 67a9f186ffcee1740602ba4ed68f78db9d37ec83 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 28 Dec 2022 10:57:55 +0100 Subject: [PATCH 08/15] avoid flake8 complain --- sklearn/_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 1e0773501a5e5..6764b739a5d46 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -4,7 +4,7 @@ import threading from contextlib import contextmanager as contextmanager -from typing import Dict +from typing import Dict # noqa from weakref import WeakKeyDictionary _global_config_default = { From e139fa71d771e2c27c4216c7df7874bc3494c565 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 28 Dec 2022 11:25:00 +0100 Subject: [PATCH 09/15] remove test that can fail with parallel execution --- sklearn/tests/test_config.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index e84ffbec54d8e..389512c742e30 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -165,13 +165,3 @@ def set_definitive_assume_finite(assume_finite, sleep_duration): assert thread_specific_config["assume_finite"] is True main_thread_config = get_config() assert main_thread_config["assume_finite"] is False - - # check that we have 2 threads registered in the thread config dictionary - from sklearn._config import _thread_config - - assert len(_thread_config) == 2 - - # delete the thread and check that the dictionary does keep a reference to it - # since we use a weakref dictionary - del thread - assert len(_thread_config) == 1 From cf16195a53a42cc2dc72d9e5bc67127ce85f1db1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Jan 2023 09:52:18 +0100 Subject: [PATCH 10/15] add thread argument to private function --- sklearn/utils/fixes.py | 7 ++++++- sklearn/utils/tests/test_fixes.py | 8 ++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index f7cd5189262fb..69a18476a412e 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -108,8 +108,13 @@ def _eigh(*args, **kwargs): # remove when https://github.com/joblib/joblib/issues/1071 is fixed -def delayed(func, thread=threading.current_thread()): +def delayed(func): """Decorator used to capture the arguments of a function.""" + return _delayed(func) + + +def _delayed(func, thread=threading.current_thread()): + """Private function to expose the thread argument.""" def decorate(func): @functools.wraps(func) diff --git a/sklearn/utils/tests/test_fixes.py b/sklearn/utils/tests/test_fixes.py index c4c579257086b..ccf8c94a5d35d 100644 --- a/sklearn/utils/tests/test_fixes.py +++ b/sklearn/utils/tests/test_fixes.py @@ -15,7 +15,7 @@ import sklearn from sklearn.utils._testing import assert_array_equal -from sklearn.utils.fixes import _object_dtype_isnan, delayed, loguniform +from sklearn.utils.fixes import _delayed, _object_dtype_isnan, loguniform @pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1])) @@ -52,7 +52,7 @@ def test_loguniform(low, high, base): def test_delayed_fetching_right_config(): - """Check that `delayed` function fetches the right config associated to + """Check that `_delayed` function fetches the right config associated to the main thread. Non-regression test for: @@ -68,7 +68,7 @@ def get_working_memory(): # parameters defined within the context manager with sklearn.config_context(working_memory=123): results = Parallel(n_jobs=2, pre_dispatch=4)( - delayed(get_working_memory)() for _ in range(n_iter) + _delayed(get_working_memory)() for _ in range(n_iter) ) assert results == [123] * n_iter @@ -79,7 +79,7 @@ def get_working_memory(): local_thread.join() with sklearn.config_context(working_memory=123): results = Parallel(n_jobs=2, pre_dispatch=4)( - delayed(get_working_memory, thread=local_thread)() for _ in range(n_iter) + _delayed(get_working_memory, thread=local_thread)() for _ in range(n_iter) ) assert results == [get_working_memory()] * n_iter From 52bf297596edd5d2df5e7d3e6ba5cb9e543f9740 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Jan 2023 09:59:26 +0100 Subject: [PATCH 11/15] MAINT remove threadlocal --- sklearn/_config.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 6764b739a5d46..4841b20e2e812 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -19,21 +19,22 @@ "array_api_dispatch": False, "transform_output": "default", } -_threadlocal = threading.local() _thread_config = WeakKeyDictionary() # type: WeakKeyDictionary[threading.Thread, Dict] -def _get_threadlocal_config(): - """Get a threadlocal **mutable** configuration. +def _get_thread_config(thread=None): + """Get a thread **mutable** configuration. If the configuration does not exist, copy the default global configuration. The configuration is also registered to a global dictionary where the keys are weak references to the thread objects. """ - if not hasattr(_threadlocal, "global_config"): - _threadlocal.global_config = _global_config_default.copy() - _thread_config[threading.current_thread()] = _threadlocal.global_config - return _threadlocal.global_config + if thread is None: + thread = threading.current_thread() + + if thread not in _thread_config: + _thread_config[thread] = _global_config_default.copy() + return _thread_config[thread] def get_config(thread=None): @@ -55,12 +56,9 @@ def get_config(thread=None): config_context : Context manager for global scikit-learn configuration. set_config : Set global scikit-learn configuration. """ - # Return a copy of the threadlocal configuration so that users will - # not be able to modify the configuration with the returned dict. - threadlocal_config = _get_threadlocal_config() - if thread is None: - return threadlocal_config.copy() - return _thread_config[thread].copy() + # Return a copy of the configuration so that users will not be able to + # modify the configuration with the returned dict. + return _get_thread_config(thread=thread).copy() def set_config( @@ -157,7 +155,7 @@ def set_config( config_context : Context manager for global scikit-learn configuration. get_config : Retrieve current values of the global configuration. """ - local_config = _get_threadlocal_config() + local_config = _get_thread_config() if assume_finite is not None: local_config["assume_finite"] = assume_finite From 566560300fd8ad80c76a6e4b5b99472219ab32c7 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Jan 2023 10:06:24 +0100 Subject: [PATCH 12/15] do not modify the API of get_config --- sklearn/_config.py | 10 ++-------- sklearn/tests/test_config.py | 10 ++++++---- sklearn/utils/fixes.py | 4 ++-- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 4841b20e2e812..2069a22b0ff76 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -37,15 +37,9 @@ def _get_thread_config(thread=None): return _thread_config[thread] -def get_config(thread=None): +def get_config(): """Retrieve current values for configuration set by :func:`set_config`. - Parameters - ---------- - thread : Thread, default=None - The thread for which to retrieve the configuration. If None, the - configuration of the current thread is returned. - Returns ------- config : dict @@ -58,7 +52,7 @@ def get_config(thread=None): """ # Return a copy of the configuration so that users will not be able to # modify the configuration with the returned dict. - return _get_thread_config(thread=thread).copy() + return _get_thread_config().copy() def set_config( diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index 389512c742e30..ad7d41f2f81db 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -5,6 +5,8 @@ from joblib import Parallel import pytest +from sklearn._config import _get_thread_config + from sklearn import get_config, set_config, config_context from sklearn.utils.fixes import delayed @@ -149,19 +151,19 @@ def test_config_threadsafe(): assert items == [False, True, False, True] -def test_get_config_thread_dependent(): +def test_get_thread_config(): """Check that we can retrieve the config file from a specific thread.""" def set_definitive_assume_finite(assume_finite, sleep_duration): set_config(assume_finite=assume_finite) time.sleep(sleep_duration) - return get_config()["assume_finite"] + return _get_thread_config()["assume_finite"] thread = threading.Thread(target=set_definitive_assume_finite, args=(True, 0.1)) thread.start() thread.join() - thread_specific_config = get_config(thread=thread) + thread_specific_config = _get_thread_config(thread=thread) assert thread_specific_config["assume_finite"] is True - main_thread_config = get_config() + main_thread_config = _get_thread_config() assert main_thread_config["assume_finite"] is False diff --git a/sklearn/utils/fixes.py b/sklearn/utils/fixes.py index 69a18476a412e..285b306f3fe74 100644 --- a/sklearn/utils/fixes.py +++ b/sklearn/utils/fixes.py @@ -21,7 +21,7 @@ import scipy import scipy.stats import threadpoolctl -from .._config import config_context, get_config +from .._config import config_context, _get_thread_config from ..externals._packaging.version import parse as parse_version @@ -131,7 +131,7 @@ class _FuncWrapper: def __init__(self, function, thread): self.function = function - self.config = get_config(thread=thread) + self.config = _get_thread_config(thread=thread) update_wrapper(self, self.function) def __call__(self, *args, **kwargs): From e42793f293201db3baf51553143880bd0050c343 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 4 Jan 2023 11:52:00 +0100 Subject: [PATCH 13/15] Apply suggestions from code review Co-authored-by: Thomas J. Fan --- sklearn/_config.py | 2 +- sklearn/tests/test_config.py | 9 ++------- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/sklearn/_config.py b/sklearn/_config.py index 2069a22b0ff76..0f5f51ff20bc0 100644 --- a/sklearn/_config.py +++ b/sklearn/_config.py @@ -3,7 +3,7 @@ import os import threading -from contextlib import contextmanager as contextmanager +from contextlib import contextmanager from typing import Dict # noqa from weakref import WeakKeyDictionary diff --git a/sklearn/tests/test_config.py b/sklearn/tests/test_config.py index ad7d41f2f81db..e9248bcac34b1 100644 --- a/sklearn/tests/test_config.py +++ b/sklearn/tests/test_config.py @@ -152,14 +152,9 @@ def test_config_threadsafe(): def test_get_thread_config(): - """Check that we can retrieve the config file from a specific thread.""" + """Check that we can retrieve the config from a specific thread.""" - def set_definitive_assume_finite(assume_finite, sleep_duration): - set_config(assume_finite=assume_finite) - time.sleep(sleep_duration) - return _get_thread_config()["assume_finite"] - - thread = threading.Thread(target=set_definitive_assume_finite, args=(True, 0.1)) + thread = threading.Thread(target=set_config, kwargs={"assume_finite": True}) thread.start() thread.join() From bb2b40dcd93e1ef2ad9e50f35d59dbce0f02f5fe Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 4 Jan 2023 11:52:27 +0100 Subject: [PATCH 14/15] DOC update changelog --- doc/whats_new/v1.2.rst | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index ccfc325c28fef..f67852a8c0b8c 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -13,20 +13,13 @@ Changes impacting all modules ----------------------------- - |Fix| Fix a bug that was ignoring the global configuration in estimators using - `n_jobs > 1`. This bug was triggered for the tasks not dispatch by the main + `n_jobs > 1`. This bug was triggered for the tasks not dispatched by the main thread in `joblib` since :func:`sklearn.get_config` uses thread local configuration. :pr:`25242` by :user:`Guillaume Lemaitre `. Changelog --------- -:mod:`sklearn` -.............. - -- |Enhancement| :func:`sklearn.get_config` takes a parameter `thread` allowing to - retrieve the local configuration of this specific `thread`. - :pr:`25242` by :user:`Guillaume Lemaitre `. - :mod:`sklearn.base` ................... From 379249e921048bc030163ee78a0ce64bc7b433be Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Wed, 4 Jan 2023 13:44:23 +0100 Subject: [PATCH 15/15] Update doc/whats_new/v1.2.rst Co-authored-by: Olivier Grisel --- doc/whats_new/v1.2.rst | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/doc/whats_new/v1.2.rst b/doc/whats_new/v1.2.rst index f67852a8c0b8c..ec91071681af2 100644 --- a/doc/whats_new/v1.2.rst +++ b/doc/whats_new/v1.2.rst @@ -12,9 +12,11 @@ Version 1.2.1 Changes impacting all modules ----------------------------- -- |Fix| Fix a bug that was ignoring the global configuration in estimators using - `n_jobs > 1`. This bug was triggered for the tasks not dispatched by the main - thread in `joblib` since :func:`sklearn.get_config` uses thread local configuration. +- |Fix| Fix a bug where the current configuration was ignored in estimators using + `n_jobs > 1`. This bug was triggered for tasks dispatched by the ancillary + thread of `joblib` as :func:`sklearn.get_config` used to access an empty thread + local configuration instead of the configuration visible from the thread where + `delayed` was first called. :pr:`25242` by :user:`Guillaume Lemaitre `. Changelog