FIX get config from dispatcher thread in delayed by default #25242
Changes from all commits
@@ -4,15 +4,18 @@
 # License: BSD 3 clause

 import math
+import threading

 import numpy as np
 import pytest
 import scipy.stats

+from joblib import Parallel
+
+import sklearn
 from sklearn.utils._testing import assert_array_equal

-from sklearn.utils.fixes import _object_dtype_isnan
-from sklearn.utils.fixes import loguniform
+from sklearn.utils.fixes import _delayed, _object_dtype_isnan, loguniform


 @pytest.mark.parametrize("dtype, val", ([object, 1], [object, "a"], [float, 1]))
@@ -46,3 +49,37 @@ def test_loguniform(low, high, base):
     assert loguniform(base**low, base**high).rvs(random_state=0) == loguniform(
         base**low, base**high
     ).rvs(random_state=0)
+
+
+def test_delayed_fetching_right_config():
+    """Check that `_delayed` function fetches the right config associated to
+    the main thread.
+
+    Non-regression test for:
+    https://github.com/scikit-learn/scikit-learn/issues/25239
+    """
+
+    def get_working_memory():
+        return sklearn.get_config()["working_memory"]
+
+    n_iter = 10
+
+    # by default, we register the main thread and we should retrieve the
+    # parameters defined within the context manager
+    with sklearn.config_context(working_memory=123):
+        results = Parallel(n_jobs=2, pre_dispatch=4)(
+            _delayed(get_working_memory)() for _ in range(n_iter)
+        )
+
+    assert results == [123] * n_iter
+
+    # simulate that we refer to another thread
+    local_thread = threading.Thread(target=sklearn.get_config)
+    local_thread.start()
+    local_thread.join()
+    with sklearn.config_context(working_memory=123):
+        results = Parallel(n_jobs=2, pre_dispatch=4)(
+            _delayed(get_working_memory, thread=local_thread)() for _ in range(n_iter)
+        )
+
+    assert results == [get_working_memory()] * n_iter
If the default working_memory is changed, e.g. with

    sklearn.set_config(working_memory=140)

    # the following fails
    assert results == [get_working_memory()] * n_iter

The less fragile assertion would be to check that the results match the default value:

    from sklearn._config import _global_config_default

    assert results == [_global_config_default["working_memory"]] * n_iter

Reviewing this test, it seems that I have found a bug in the way the default value of the thread argument of delayed is defined. I am working on a PR against this PR.
Capturing the current thread here is problematic because it makes the behavior of scikit-learn dependent on which thread first imported scikit-learn, and scikit-learn's behavior is no longer thread-symmetric.

I tried changing this, but it does not work either. Thread inspection does not work as intended (it happens too late) when calling Parallel on a generator expression, which is the canonical way to use joblib. Instead, we should capture the state of the config of the thread that calls Parallel just before the call happens and ship it to all the dispatched tasks.

We just had a live pair-debugging / programming session with @glemaitre on Discord and I think we came up with a better solution that is a bit more verbose but also much more explicit (and correct ;). He will open a PR soon and we will be able to have a more informed technical discussion there.

For the longer term, we could expose a hook in joblib to better handle this kind of configuration propagation, but having a stopgap fix in scikit-learn makes it possible to decouple scikit-learn from the version of joblib.
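The idea of capturing the caller's config just before the Parallel call can be illustrated with only the public joblib and scikit-learn APIs. This is a minimal, hedged sketch, not the fix that was actually merged; the helper name run_with_config is invented for this example:

    # Snapshot the config in the thread that calls Parallel and pass the
    # snapshot explicitly to every task, instead of reading it lazily from
    # whichever thread happens to consume the generator.
    import joblib

    from sklearn import config_context, get_config


    def run_with_config(config, function, *args, **kwargs):
        # Re-apply the captured configuration around the actual work.
        with config_context(**config):
            return function(*args, **kwargs)


    def get_working_memory():
        return get_config()["working_memory"]


    with config_context(working_memory=123):
        # Snapshot taken in the thread calling Parallel, before the joblib
        # dispatcher thread starts consuming the generator expression.
        config = get_config()
        results = joblib.Parallel(n_jobs=2, pre_dispatch=4)(
            joblib.delayed(run_with_config)(config, get_working_memory)
            for _ in range(10)
        )

    assert results == [123] * 10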
Refer to #25290 for the better solution
👍 I was afraid it would come to this kind of more verbose solution. Maybe at the same time this is merged to enable the fix, a separate issue could be opened to discuss the ins and outs of the per-thread config? Unless the behavior that is enforced and supported is already clear, but that didn't seem to be the case (the PR where the behavior was enabled does not discuss it much).
It is true that we did not discuss this here, but we did during the pair-programming session.

@ogrisel agreed that we should keep the current behavior, where one thread cannot modify the config that other threads may be using. It is a bit counter-intuitive if we rely on the fact that threads share memory, but the side effects within scikit-learn could be bad: for instance, you could get random errors that are not reproducible because they would depend on the config state at a particular moment.
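As a small standalone illustration of that per-thread behavior (not taken from the scikit-learn test suite; it assumes the default working_memory of 1024 has not been changed elsewhere):

    # Each thread works on its own copy of the scikit-learn configuration, so a
    # config_context entered in the main thread does not leak into other threads.
    import threading

    import sklearn

    seen = {}


    def worker():
        seen["worker"] = sklearn.get_config()["working_memory"]


    with sklearn.config_context(working_memory=123):
        seen["main"] = sklearn.get_config()["working_memory"]
        t = threading.Thread(target=worker)
        t.start()
        t.join()

    print(seen)  # {'main': 123, 'worker': 1024} with the default configuration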
I think we should discuss the two options at the next dev meeting. I think I am in favor of option 2, but it is worth discussing this with others.
Option 1 would resolve the issue for multi-threading, but I think the issue will remain for multiprocessing or loky. I am okay with Option 2. Most of my concern is how third-party developers using joblib would need to update their code to use utils.fixes.delayed to work with scikit-learn's config.
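As a rough illustration of that concern, a third-party project would have to swap joblib.delayed for the scikit-learn wrapper. The snippet below is hypothetical; it uses the sklearn.utils.fixes.delayed import path referred to in this discussion (later scikit-learn versions expose the helper as sklearn.utils.parallel.delayed instead):

    # Hypothetical third-party code: fitting several estimators in parallel.
    # Using scikit-learn's delayed wrapper (instead of joblib.delayed) is what
    # lets each task see the scikit-learn config of the calling thread.
    from joblib import Parallel

    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.utils.fixes import delayed  # import path as of this PR's era

    X, y = make_classification(random_state=0)


    def fit_one(estimator, X, y):
        return estimator.fit(X, y)


    # With plain joblib.delayed(fit_one) the tasks would run fine, but they
    # would not see settings from sklearn.config_context in the calling thread.
    fitted = Parallel(n_jobs=2)(
        delayed(fit_one)(LogisticRegression(C=c), X, y) for c in (0.1, 1.0, 10.0)
    )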
From the developer call: joblib uses another thread to get jobs from a generator, which means Option 1 with a thread-local configuration would resolve the current issue.
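That dispatching behavior can be observed with a small standalone snippet (illustrative only, not from the scikit-learn code base): the first pre-dispatched items are consumed by the calling thread, while later items are pulled from the generator by a joblib-internal thread.

    # Print the name of the thread that consumes each item of the generator
    # handed to Parallel.  After the initial pre-dispatch, the printed name is
    # typically no longer MainThread.
    import threading

    from joblib import Parallel, delayed


    def task(i):
        return i


    def job_generator(n):
        for i in range(n):
            print("item", i, "dispatched from:", threading.current_thread().name)
            yield delayed(task)(i)


    print("Parallel called from:", threading.current_thread().name)
    Parallel(n_jobs=2, pre_dispatch=2)(job_generator(10))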