Skip to content

Commit 8bf3aa5

Browse files
glemaitrejeremiedbbogrisel
authored
FIX change the meaning of include_boundaries in check_scalar (scikit-learn#20921)
Co-authored-by: Jérémie du Boisberranger <34657725+jeremiedbb@users.noreply.github.com> Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 23cf99f commit 8bf3aa5

File tree

3 files changed

+46
-21
lines changed

3 files changed

+46
-21
lines changed

sklearn/cluster/_affinity_propagation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ class AffinityPropagation(ClusterMixin, BaseEstimator):
272272
Parameters
273273
----------
274274
damping : float, default=0.5
275-
Damping factor (between 0.5 and 1) is the extent to
275+
Damping factor in the range `[0.5, 1.0)` is the extent to
276276
which the current value is maintained relative to
277277
incoming values (weighted 1 - damping). This in order
278278
to avoid numerical oscillations when updating these
@@ -469,7 +469,7 @@ def fit(self, X, y=None):
469469
target_type=numbers.Real,
470470
min_val=0.5,
471471
max_val=1,
472-
closed="right",
472+
include_boundaries="left",
473473
)
474474
check_scalar(self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1)
475475
check_scalar(

sklearn/utils/tests/test_validation.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,14 +1032,14 @@ def test_check_scalar_valid(x):
10321032
target_type=numbers.Real,
10331033
min_val=2,
10341034
max_val=5,
1035-
closed="neither",
1035+
include_boundaries="both",
10361036
)
10371037
assert len(record) == 0
10381038
assert scalar == x
10391039

10401040

10411041
@pytest.mark.parametrize(
1042-
"x, target_name, target_type, min_val, max_val, closed, err_msg",
1042+
"x, target_name, target_type, min_val, max_val, include_boundaries, err_msg",
10431043
[
10441044
(
10451045
1,
@@ -1059,7 +1059,7 @@ def test_check_scalar_valid(x):
10591059
2,
10601060
4,
10611061
"neither",
1062-
ValueError("test_name2 == 1, must be >= 2."),
1062+
ValueError("test_name2 == 1, must be > 2."),
10631063
),
10641064
(
10651065
5,
@@ -1068,15 +1068,15 @@ def test_check_scalar_valid(x):
10681068
2,
10691069
4,
10701070
"neither",
1071-
ValueError("test_name3 == 5, must be <= 4."),
1071+
ValueError("test_name3 == 5, must be < 4."),
10721072
),
10731073
(
10741074
2,
10751075
"test_name4",
10761076
int,
10771077
2,
10781078
4,
1079-
"left",
1079+
"right",
10801080
ValueError("test_name4 == 2, must be > 2."),
10811081
),
10821082
(
@@ -1085,13 +1085,25 @@ def test_check_scalar_valid(x):
10851085
int,
10861086
2,
10871087
4,
1088-
"right",
1088+
"left",
10891089
ValueError("test_name5 == 4, must be < 4."),
10901090
),
1091+
(
1092+
4,
1093+
"test_name6",
1094+
int,
1095+
2,
1096+
4,
1097+
"bad parameter value",
1098+
ValueError(
1099+
"Unknown value for `include_boundaries`: 'bad parameter value'. "
1100+
"Possible values are: ('left', 'right', 'both', 'neither')."
1101+
),
1102+
),
10911103
],
10921104
)
10931105
def test_check_scalar_invalid(
1094-
x, target_name, target_type, min_val, max_val, closed, err_msg
1106+
x, target_name, target_type, min_val, max_val, include_boundaries, err_msg
10951107
):
10961108
"""Test that check_scalar returns the right error if a wrong input is
10971109
given"""
@@ -1102,7 +1114,7 @@ def test_check_scalar_invalid(
11021114
target_type=target_type,
11031115
min_val=min_val,
11041116
max_val=max_val,
1105-
closed=closed,
1117+
include_boundaries=include_boundaries,
11061118
)
11071119
assert str(raised_error.value) == str(err_msg)
11081120
assert type(raised_error.value) == type(err_msg)

sklearn/utils/validation.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ def check_scalar(
12421242
*,
12431243
min_val=None,
12441244
max_val=None,
1245-
closed="neither",
1245+
include_boundaries="both",
12461246
):
12471247
"""Validate scalar parameters type and value.
12481248
@@ -1265,9 +1265,15 @@ def check_scalar(
12651265
The maximum valid value the parameter can take. If None (default) it
12661266
is implied that the parameter does not have an upper bound.
12671267
1268-
closed : {"left", "right", "both", "neither"}, default="neither"
1269-
Whether the interval is closed on the left-side, right-side, both or
1270-
neither.
1268+
include_boundaries : {"left", "right", "both", "neither"}, default="both"
1269+
Whether the interval defined by `min_val` and `max_val` should include
1270+
the boundaries. Possible choices are:
1271+
1272+
- `"left"`: only `min_val` is included in the valid interval;
1273+
- `"right"`: only `max_val` is included in the valid interval;
1274+
- `"both"`: `min_val` and `max_val` are included in the valid interval;
1275+
- `"neither"`: neither `min_val` nor `max_val` are included in the
1276+
valid interval.
12711277
12721278
Returns
12731279
-------
@@ -1286,22 +1292,29 @@ def check_scalar(
12861292
if not isinstance(x, target_type):
12871293
raise TypeError(f"{name} must be an instance of {target_type}, not {type(x)}.")
12881294

1289-
expected_closed = {"left", "right", "both", "neither"}
1290-
if closed not in expected_closed:
1291-
raise ValueError(f"Unknown value for `closed`: {closed}")
1295+
expected_include_boundaries = ("left", "right", "both", "neither")
1296+
if include_boundaries not in expected_include_boundaries:
1297+
raise ValueError(
1298+
f"Unknown value for `include_boundaries`: {repr(include_boundaries)}. "
1299+
f"Possible values are: {expected_include_boundaries}."
1300+
)
12921301

1293-
comparison_operator = operator.le if closed in ("left", "both") else operator.lt
1302+
comparison_operator = (
1303+
operator.lt if include_boundaries in ("left", "both") else operator.le
1304+
)
12941305
if min_val is not None and comparison_operator(x, min_val):
12951306
raise ValueError(
12961307
f"{name} == {x}, must be"
1297-
f" {'>' if closed in ('left', 'both') else '>='} {min_val}."
1308+
f" {'>=' if include_boundaries in ('left', 'both') else '>'} {min_val}."
12981309
)
12991310

1300-
comparison_operator = operator.ge if closed in ("right", "both") else operator.gt
1311+
comparison_operator = (
1312+
operator.gt if include_boundaries in ("right", "both") else operator.ge
1313+
)
13011314
if max_val is not None and comparison_operator(x, max_val):
13021315
raise ValueError(
13031316
f"{name} == {x}, must be"
1304-
f" {'<' if closed in ('right', 'both') else '<='} {max_val}."
1317+
f" {'<=' if include_boundaries in ('right', 'both') else '<'} {max_val}."
13051318
)
13061319

13071320
return x

0 commit comments

Comments
 (0)