Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.base/31528.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
- Fix regression in HTML representation when detecting the non-default parameters
that where of array-like types.
By :user:`Dea María Léon <deamarialeon>`
6 changes: 4 additions & 2 deletions sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,14 @@ def is_non_default(param_name, param_value):
init_default_params[param_name]
):
return True

if param_value != init_default_params[param_name] and not (
if not np.array_equal(
param_value, init_default_params[param_name]
) and not (
is_scalar_nan(init_default_params[param_name])
and is_scalar_nan(param_value)
):
return True

return False

# reorder the parameters from `self.get_params` using the `__init__`
Expand Down
81 changes: 80 additions & 1 deletion sklearn/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from sklearn.decomposition import PCA
from sklearn.ensemble import IsolationForest
from sklearn.exceptions import InconsistentVersionWarning
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import get_scorer
from sklearn.model_selection import GridSearchCV, KFold
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC, SVR
Expand Down Expand Up @@ -1000,3 +1001,81 @@ def test_get_params_html():

assert est._get_params_html() == {"l1": 0, "empty": "test"}
assert est._get_params_html().non_default == ("empty",)


def make_estimator_with_param(default_value):
class DynamicEstimator(BaseEstimator):
def __init__(self, param=default_value):
self.param = param

return DynamicEstimator


@pytest.mark.parametrize(
"default_value, test_value",
[
((), (1,)),
((), [1]),
((), np.array([1])),
((1, 2), (3, 4)),
((1, 2), [3, 4]),
((1, 2), np.array([3, 4])),
(None, 1),
(None, []),
(None, lambda x: x),
(np.nan, 1.0),
(np.nan, np.array([np.nan])),
("abc", "def"),
("abc", ["abc"]),
(True, False),
(1, 2),
(1, [1]),
(1, np.array([1])),
(1.0, 2.0),
(1.0, [1.0]),
(1.0, np.array([1.0])),
([1, 2], [3]),
(np.array([1]), [2, 3]),
(None, KFold()),
(None, get_scorer("accuracy")),
],
)
def test_param_is_non_default(default_value, test_value):
"""Check that we detect non-default parameters with various types.

Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/31525
"""
estimator = make_estimator_with_param(default_value)(param=test_value)
non_default = estimator._get_params_html().non_default
assert "param" in non_default


@pytest.mark.parametrize(
"default_value, test_value",
[
(None, None),
((), ()),
((), []),
((), np.array([])),
((1, 2, 3), (1, 2, 3)),
((1, 2, 3), [1, 2, 3]),
((1, 2, 3), np.array([1, 2, 3])),
(np.nan, np.nan),
("abc", "abc"),
(True, True),
(1, 1),
(1.0, 1.0),
(2, 2.0),
],
)
def test_param_is_default(default_value, test_value):
"""Check that we detect the default parameters and values in an array-like will
be reported as default as well.

Non-regression test for:
https://github.com/scikit-learn/scikit-learn/issues/31525
"""
estimator = make_estimator_with_param(default_value)(param=test_value)
non_default = estimator._get_params_html().non_default
assert "param" not in non_default