Skip to content

MAINT Added Parameter Validation for datasets.make_circles() #25848

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions sklearn/datasets/_samples_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,15 @@ def make_regression(
return X, y


+@validate_params(
+    {
+        "n_samples": [Interval(Integral, 0, None, closed="left"), tuple],
+        "shuffle": ["boolean"],
+        "noise": [Interval(Real, 0, None, closed="left"), None],
+        "random_state": ["random_state"],
+        "factor": [Interval(Real, 0, 1, closed="left")],
+    }
+)
def make_circles(
n_samples=100, *, shuffle=True, noise=None, random_state=None, factor=0.8
):
Expand Down Expand Up @@ -706,7 +715,7 @@ def make_circles(
See :term:`Glossary <random_state>`.

factor : float, default=.8
-        Scale factor between inner and outer circle in the range `(0, 1)`.
+        Scale factor between inner and outer circle in the range `[0, 1)`.

Returns
-------
Expand All @@ -716,20 +725,13 @@ def make_circles(
y : ndarray of shape (n_samples,)
The integer labels (0 or 1) for class membership of each sample.
"""

-    if factor >= 1 or factor < 0:
-        raise ValueError("'factor' has to be between 0 and 1.")
-
     if isinstance(n_samples, numbers.Integral):
         n_samples_out = n_samples // 2
         n_samples_in = n_samples - n_samples_out
-    else:
-        try:
-            n_samples_out, n_samples_in = n_samples
-        except ValueError as e:
-            raise ValueError(
-                "`n_samples` can be either an int or a two-element tuple."
-            ) from e
+    else:  # n_samples is a tuple
+        if len(n_samples) != 2:
+            raise ValueError("When a tuple, n_samples must have exactly two elements.")
+        n_samples_out, n_samples_in = n_samples

generator = check_random_state(random_state)
# so as not to have the first point = last point, we set endpoint=False
Expand Down
13 changes: 1 addition & 12 deletions sklearn/datasets/tests/test_samples_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,11 +688,6 @@ def test_make_circles():
2,
), "Samples not correctly distributed across circles."

-    with pytest.raises(ValueError):
-        make_circles(factor=-0.01)
-    with pytest.raises(ValueError):
-        make_circles(factor=1.0)


def test_make_circles_unbalanced():
X, y = make_circles(n_samples=(2, 8))
Expand All @@ -704,12 +699,6 @@ def test_make_circles_unbalanced():

     with pytest.raises(
         ValueError,
-        match=r"`n_samples` can be either an int " r"or a two-element tuple.",
-    ):
-        make_circles(n_samples=[1, 2, 3])
-
-    with pytest.raises(
-        ValueError,
-        match=r"`n_samples` can be either an int " r"or a two-element tuple.",
+        match="When a tuple, n_samples must have exactly two elements.",
     ):
         make_circles(n_samples=(10,))
1 change: 1 addition & 0 deletions sklearn/tests/test_public_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _check_function_param_validation(
"sklearn.datasets.fetch_olivetti_faces",
"sklearn.datasets.load_svmlight_file",
"sklearn.datasets.load_svmlight_files",
"sklearn.datasets.make_circles",
"sklearn.datasets.make_classification",
"sklearn.datasets.make_friedman1",
"sklearn.datasets.make_sparse_coded_signal",
Expand Down