diff --git a/sklearn/_loss/tests/test_link.py b/sklearn/_loss/tests/test_link.py index 435361eaa50f1..c083883d3d650 100644 --- a/sklearn/_loss/tests/test_link.py +++ b/sklearn/_loss/tests/test_link.py @@ -58,20 +58,21 @@ def test_is_in_range(interval): @pytest.mark.parametrize("link", LINK_FUNCTIONS) -def test_link_inverse_identity(link): +def test_link_inverse_identity(link, global_random_seed): # Test that link of inverse gives identity. - rng = np.random.RandomState(42) + rng = np.random.RandomState(global_random_seed) link = link() n_samples, n_classes = 100, None + # The values for `raw_prediction` are limited from -20 to 20 because in the + # class `LogitLink` the term `expit(x)` comes very close to 1 for large + # positive x and therefore loses precision. if link.is_multiclass: n_classes = 10 - raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples, n_classes)) + raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples, n_classes)) if isinstance(link, MultinomialLogit): raw_prediction = link.symmetrize_raw_prediction(raw_prediction) else: - # So far, the valid interval of raw_prediction is (-inf, inf) and - # we do not need to distinguish. - raw_prediction = rng.normal(loc=0, scale=10, size=(n_samples)) + raw_prediction = rng.uniform(low=-20, high=20, size=(n_samples)) assert_allclose(link.link(link.inverse(raw_prediction)), raw_prediction) y_pred = link.inverse(raw_prediction)