Skip to content

Commit 38a1e64

Browse files
naoise-hjeremiedbb
andauthored
FIX Param validation Interval error for large integers (#26648)
Co-authored-by: jeremie du boisberranger <jeremiedbb@yahoo.fr>
1 parent 32f8bda commit 38a1e64

File tree

4 files changed

+43
-7
lines changed

4 files changed

+43
-7
lines changed

doc/whats_new/v1.4.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ Changes impacting all modules
2929
to work with our estimators and functions.
3030
:pr:`26464` by `Thomas Fan`_.
3131

32+
- |Fix| Fixed a bug in most estimators and functions where setting a parameter to
33+
a large integer would cause a `TypeError`.
34+
:pr:`26648` by :user:`Naoise Holohan <naoise-h>`.
35+
3236
Metadata Routing
3337
----------------
3438

sklearn/utils/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,7 +1094,11 @@ def is_scalar_nan(x):
10941094
>>> is_scalar_nan([np.nan])
10951095
False
10961096
"""
1097-
return isinstance(x, numbers.Real) and math.isnan(x)
1097+
return (
1098+
not isinstance(x, numbers.Integral)
1099+
and isinstance(x, numbers.Real)
1100+
and math.isnan(x)
1101+
)
10981102

10991103

11001104
def _approximate_mode(class_counts, n_draws, rng):

sklearn/utils/_param_validation.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,9 @@ class _NanConstraint(_Constraint):
311311
"""Constraint representing the indicator `np.nan`."""
312312

313313
def is_satisfied_by(self, val):
314-
return isinstance(val, Real) and math.isnan(val)
314+
return (
315+
not isinstance(val, Integral) and isinstance(val, Real) and math.isnan(val)
316+
)
315317

316318
def __str__(self):
317319
return "numpy.nan"
@@ -475,7 +477,7 @@ def _check_params(self):
475477
)
476478

477479
def __contains__(self, val):
478-
if np.isnan(val):
480+
if not isinstance(val, Integral) and np.isnan(val):
479481
return False
480482

481483
left_cmp = operator.lt if self.closed in ("left", "both") else operator.le

sklearn/utils/tests/test_param_validation.py

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,16 +74,41 @@ def fit(self, X=None, y=None):
7474
def test_interval_range(interval_type):
7575
"""Check the range of values depending on closed."""
7676
interval = Interval(interval_type, -2, 2, closed="left")
77-
assert -2 in interval and 2 not in interval
77+
assert -2 in interval
78+
assert 2 not in interval
7879

7980
interval = Interval(interval_type, -2, 2, closed="right")
80-
assert -2 not in interval and 2 in interval
81+
assert -2 not in interval
82+
assert 2 in interval
8183

8284
interval = Interval(interval_type, -2, 2, closed="both")
83-
assert -2 in interval and 2 in interval
85+
assert -2 in interval
86+
assert 2 in interval
8487

8588
interval = Interval(interval_type, -2, 2, closed="neither")
86-
assert -2 not in interval and 2 not in interval
89+
assert -2 not in interval
90+
assert 2 not in interval
91+
92+
93+
@pytest.mark.parametrize("interval_type", [Integral, Real])
94+
def test_interval_large_integers(interval_type):
95+
"""Check that Interval constraint work with large integers.
96+
97+
non-regression test for #26648.
98+
"""
99+
interval = Interval(interval_type, 0, 2, closed="neither")
100+
assert 2**65 not in interval
101+
assert 2**128 not in interval
102+
assert float(2**65) not in interval
103+
assert float(2**128) not in interval
104+
105+
interval = Interval(interval_type, 0, 2**128, closed="neither")
106+
assert 2**65 in interval
107+
assert 2**128 not in interval
108+
assert float(2**65) in interval
109+
assert float(2**128) not in interval
110+
111+
assert 2**1024 not in interval
87112

88113

89114
def test_interval_inf_in_bounds():
@@ -389,6 +414,7 @@ def test_generate_valid_param(constraint):
389414
("verbose", 1),
390415
(MissingValues(), -1),
391416
(MissingValues(), -1.0),
417+
(MissingValues(), 2**1028),
392418
(MissingValues(), None),
393419
(MissingValues(), float("nan")),
394420
(MissingValues(), np.nan),

0 commit comments

Comments
 (0)