Skip to content

Warning the user of bad default values, starting by dbscan.eps #14942

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
import sys
import logging
import os
import warnings
from .exceptions import BadDefaultWarning

from ._config import get_config, set_config, config_context

logger = logging.getLogger(__name__)
warnings.filterwarnings("once", category=BadDefaultWarning)


# PEP0440 compatible formatted version, see:
Expand Down
19 changes: 16 additions & 3 deletions sklearn/cluster/_dbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from ..base import BaseEstimator, ClusterMixin
from ..utils.validation import _check_sample_weight, _deprecate_positional_args
from ..neighbors import NearestNeighbors
from ..utils.validation import _validate_bad_defaults

from ._dbscan_inner import dbscan_inner

Expand All @@ -42,6 +43,9 @@ def dbscan(X, eps=0.5, *, min_samples=5, metric='minkowski',
important DBSCAN parameter to choose appropriately for your data set
and distance function.

Note that there is no good default value for this parameter. An
optimal value depends on the data at hand as well as the used metric.

min_samples : int, default=5
The number of samples (or total weight) in a neighborhood for a point
to be considered as a core point. This includes the point itself.
Expand Down Expand Up @@ -165,6 +169,11 @@ class DBSCAN(ClusterMixin, BaseEstimator):
important DBSCAN parameter to choose appropriately for your data set
and distance function.

Note that there is no good default value for this parameter. An
optimal value depends on the data at hand as well as the used metric.
If not specified, a warning is raised and the default value of 0.5 is
used.

min_samples : int, default=5
The number of samples (or total weight) in a neighborhood for a point
to be considered as a core point. This includes the point itself.
Expand Down Expand Up @@ -271,8 +280,11 @@ class DBSCAN(ClusterMixin, BaseEstimator):
DBSCAN revisited, revisited: why and how you should (still) use DBSCAN.
ACM Transactions on Database Systems (TODS), 42(3), 19.
"""

_bad_defaults = {'eps': 0.5}

@_deprecate_positional_args
def __init__(self, eps=0.5, *, min_samples=5, metric='euclidean',
def __init__(self, eps='warn', *, min_samples=5, metric='euclidean',
metric_params=None, algorithm='auto', leaf_size=30, p=None,
n_jobs=None):
self.eps = eps
Expand Down Expand Up @@ -310,8 +322,9 @@ def fit(self, X, y=None, sample_weight=None):

"""
X = self._validate_data(X, accept_sparse='csr')
eps = _validate_bad_defaults(self)['eps']

if not self.eps > 0.0:
if not eps > 0.0:
raise ValueError("eps must be positive.")

if sample_weight is not None:
Expand All @@ -328,7 +341,7 @@ def fit(self, X, y=None, sample_weight=None):
X.setdiag(X.diagonal()) # XXX: modifies X's internals in-place

neighbors_model = NearestNeighbors(
radius=self.eps, algorithm=self.algorithm,
radius=eps, algorithm=self.algorithm,
leaf_size=self.leaf_size, metric=self.metric,
metric_params=self.metric_params, p=self.p, n_jobs=self.n_jobs)
neighbors_model.fit(X)
Expand Down
14 changes: 13 additions & 1 deletion sklearn/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
'NonBLASDotWarning',
'SkipTestWarning',
'UndefinedMetricWarning',
'PositiveSpectrumWarning']
'PositiveSpectrumWarning',
'BadDefaultWarning']


class NotFittedError(ValueError, AttributeError):
Expand Down Expand Up @@ -147,3 +148,14 @@ class PositiveSpectrumWarning(UserWarning):

.. versionadded:: 0.22
"""


class BadDefaultWarning(UserWarning):
"""Warning raised for unspecified parameters with no good default.

This warning is typically raised by _validate_bad_defaults when the user
does not specify a value for a parameter with no good default value. An
example is the ``eps`` in :class:`cluster.DBSCAN`.

.. versionadded: 0.24
"""
51 changes: 51 additions & 0 deletions sklearn/utils/tests/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np
import scipy.sparse as sp

from sklearn.base import BaseEstimator
from sklearn.utils._testing import assert_no_warnings
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import SkipTest
Expand Down Expand Up @@ -43,12 +44,14 @@
_deprecate_positional_args,
_check_sample_weight,
_allclose_dense_sparse,
_validate_bad_defaults,
FLOAT_DTYPES)
from sklearn.utils.validation import _check_fit_params

import sklearn

from sklearn.exceptions import NotFittedError, PositiveSpectrumWarning
from sklearn.exceptions import BadDefaultWarning

from sklearn.utils._testing import TempMemmap

Expand Down Expand Up @@ -1107,6 +1110,54 @@ def test_allclose_dense_sparse_raise(toarray):
_allclose_dense_sparse(x, y)


def test_validate_bad_params():
msg1 = ("There is no good default value for the following parameters in "
"A. Please consult the documentation on how to set them for your "
"data."
"\n 'param_a' - using default value: 1"
"\n 'param_b' - using default value: 'kmeans'")
msg2 = ("There is no good default value for the following parameters in "
"A. Please consult the documentation on how to set them for your "
"data."
"\n 'param_b' - using default value: 'kmeans'")

class A(BaseEstimator):
# The param_c should not warn as a result of _validate_bad_defaults
# since it's not included in _bad_defaults
_bad_defaults = {'param_a': 1, 'param_b': 'kmeans'}

def __init__(self, param_a='warn', param_b='warn', param_c='warn',
param_d=0):
self.param_a = param_a
self.param_b = param_b
self.param_c = param_c
self.param_d = param_d

def fit(self, X=None, y=None):
_validate_bad_defaults(self)
return self

with pytest.warns(BadDefaultWarning, match=msg1):
A().fit()

# should not warn the second time
with warnings.catch_warnings(record=True) as warns:
A().fit()
assert not warns

with pytest.warns(BadDefaultWarning, match=msg2):
A(param_a=1).fit()

# should not warn the second time
with warnings.catch_warnings(record=True) as warns:
A(param_a=1).fit()
assert not warns

with warnings.catch_warnings(record=True) as warns:
A(param_a=1, param_b='dbscan').fit()
assert not warns


def test_deprecate_positional_args_warns_for_function():

@_deprecate_positional_args
Expand Down
22 changes: 22 additions & 0 deletions sklearn/utils/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ..exceptions import NonBLASDotWarning, PositiveSpectrumWarning
from ..exceptions import NotFittedError
from ..exceptions import DataConversionWarning
from ..exceptions import BadDefaultWarning

FLOAT_DTYPES = (np.float64, np.float32, np.float16)

Expand Down Expand Up @@ -1342,6 +1343,27 @@ def _allclose_dense_sparse(x, y, rtol=1e-7, atol=1e-9):
"matrix and an array")


def _validate_bad_defaults(obj):
if not hasattr(obj, "_bad_defaults"):
return

obj_values = {param: getattr(obj, param) for param in obj._bad_defaults}
bad_params = sorted([param for param, value in obj_values.items()
if value == 'warn'])
if bad_params:
msg = ("There is no good default value for the following "
"parameters in {}. Please consult the documentation "
"on how to set them for your data.\n ".format(
obj.__class__.__name__))
msg += '\n '.join(["'{}' - using default value: {!r}".format(
param, obj._bad_defaults[param]) for param in bad_params])
warnings.warn(msg, BadDefaultWarning)
all_params = obj.get_params()
for param in bad_params:
all_params[param] = obj._bad_defaults[param]
return all_params


def _check_fit_params(X, fit_params, indices=None):
"""Check and validate the parameters passed during `fit`.

Expand Down