Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion sklearn/model_selection/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,7 @@ def _yield_masked_array_for_each_param(candidate_params):

# Use one MaskedArray and mask all the places where the param is not
# applicable for that candidate (which may not contain all the params).
ma = MaskedArray(np.empty(n_candidates), mask=True, dtype=arr_dtype)
ma = MaskedArray(np.empty(n_candidates, dtype=arr_dtype), mask=True)
for index, value in param_result.items():
# Setting the value at an index unmasks that index
ma[index] = value
Expand Down
8 changes: 8 additions & 0 deletions sklearn/model_selection/tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2864,3 +2864,11 @@ def test_yield_masked_array_for_each_param(candidate_params, expected):
assert value.dtype == expected_value.dtype
np.testing.assert_array_equal(value, expected_value)
np.testing.assert_array_equal(value.mask, expected_value.mask)


def test_yield_masked_array_no_runtime_warning():
# non-regression test for https://github.com/scikit-learn/scikit-learn/issues/29929
candidate_params = [{"param": i} for i in range(1000)]
with warnings.catch_warnings():
warnings.simplefilter("error", RuntimeWarning)
list(_yield_masked_array_for_each_param(candidate_params))
Loading