Skip to content

[MRG + 1] move custom error/warning classes into sklearn.exceptions (and move deprecated away from utils.__init__.py) #4826

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 20, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/developers/performance.rst
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ silently dispatched to ``numpy.dot``. If you want to be sure when the original
activate the related warning::

>>> import warnings
>>> from sklearn.utils.validation import NonBLASDotWarning
>>> from sklearn.exceptions import NonBLASDotWarning
>>> warnings.simplefilter('always', NonBLASDotWarning) # doctest: +SKIP

.. _profiling-python-code:
Expand Down
4 changes: 2 additions & 2 deletions doc/developers/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -293,5 +293,5 @@ Warnings and Exceptions

- :class:`deprecated`: Decorator to mark a function or class as deprecated.

- :class:`ConvergenceWarning`: Custom warning to catch convergence problems.
Used in ``sklearn.covariance.graph_lasso``.
- :class:`sklearn.exceptions.ConvergenceWarning`: Custom warning to catch
convergence problems. Used in ``sklearn.covariance.graph_lasso``.
25 changes: 25 additions & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,31 @@ partial dependence
ensemble.partial_dependence.plot_partial_dependence


.. _exceptions_ref:

:mod:`sklearn.exceptions`: Exceptions and warnings
==================================================

.. automodule:: sklearn.exceptions
:no-members:
:no-inherited-members:

.. currentmodule:: sklearn

.. autosummary::
:toctree: generated/
:template: class_without_init.rst

exceptions.NotFittedError
exceptions.ChangedBehaviorWarning
exceptions.ConvergenceWarning
exceptions.DataConversionWarning
exceptions.DataDimensionalityWarning
exceptions.EfficiencyWarning
exceptions.FitFailedWarning
exceptions.NonBLASDotWarning
exceptions.UndefinedMetricWarning

.. _feature_extraction_ref:

:mod:`sklearn.feature_extraction`: Feature Extraction
Expand Down
12 changes: 12 additions & 0 deletions doc/templates/class_without_init.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
:mod:`{{module}}`.{{objname}}
{{ underline }}==============

.. currentmodule:: {{ module }}

.. autoclass:: {{ objname }}

.. include:: {{module}}.{{objname}}.examples

.. raw:: html

<div class="clearer"></div>
2 changes: 1 addition & 1 deletion examples/linear_model/plot_sparse_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
from sklearn.metrics import auc, precision_recall_curve
from sklearn.ensemble import ExtraTreesRegressor
from sklearn.utils.extmath import pinvh
from sklearn.utils import ConvergenceWarning
from sklearn.exceptions import ConvergenceWarning


def mutual_incoherence(X_relevant, X_irelevant):
Expand Down
2 changes: 1 addition & 1 deletion sklearn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@

__all__ = ['calibration', 'cluster', 'covariance', 'cross_decomposition',
'cross_validation', 'datasets', 'decomposition', 'dummy',
'ensemble', 'externals', 'feature_extraction',
'ensemble', 'exceptions', 'externals', 'feature_extraction',
'feature_selection', 'gaussian_process', 'grid_search',
'isotonic', 'kernel_approximation', 'kernel_ridge',
'lda', 'learning_curve',
Expand Down
9 changes: 8 additions & 1 deletion sklearn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,18 @@
from scipy import sparse
from .externals import six
from .utils.fixes import signature
from .utils.deprecation import deprecated
from .exceptions import ChangedBehaviorWarning as ChangedBehaviorWarning_


class ChangedBehaviorWarning(UserWarning):
class ChangedBehaviorWarning(ChangedBehaviorWarning_):
pass

ChangedBehaviorWarning = deprecated("ChangedBehaviorWarning has been moved "
"into the sklearn.exceptions module. "
"It will not be available here from "
"version 0.19")(ChangedBehaviorWarning)


##############################################################################
def clone(estimator, safe=True):
Expand Down
3 changes: 2 additions & 1 deletion sklearn/cluster/birch.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from ..externals.six.moves import xrange
from ..utils import check_array
from ..utils.extmath import row_norms, safe_sparse_dot
from ..utils.validation import NotFittedError, check_is_fitted
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
from .hierarchical import AgglomerativeClustering


Expand Down
2 changes: 1 addition & 1 deletion sklearn/cluster/tests/test_k_means.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from sklearn.utils.testing import assert_raise_message


from sklearn.utils.validation import DataConversionWarning
from sklearn.utils.extmath import row_norms
from sklearn.metrics.cluster import v_measure_score
from sklearn.cluster import KMeans, k_means
Expand All @@ -29,6 +28,7 @@
from sklearn.cluster.k_means_ import _mini_batch_step
from sklearn.datasets.samples_generator import make_blobs
from sklearn.externals.six.moves import cStringIO as StringIO
from sklearn.exceptions import DataConversionWarning


# non centered, sparse centers to check the
Expand Down
2 changes: 1 addition & 1 deletion sklearn/covariance/graph_lasso_.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .empirical_covariance_ import (empirical_covariance, EmpiricalCovariance,
log_likelihood)

from ..utils import ConvergenceWarning
from ..exceptions import ConvergenceWarning
from ..utils.extmath import pinvh
from ..utils.validation import check_random_state, check_array
from ..linear_model import lars_path
Expand Down
4 changes: 2 additions & 2 deletions sklearn/covariance/tests/test_robust_covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_raises, assert_warns
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_raise_message
from sklearn.utils.validation import NotFittedError
from sklearn.exceptions import NotFittedError

from sklearn import datasets
from sklearn.covariance import empirical_covariance, MinCovDet, \
Expand Down
5 changes: 1 addition & 4 deletions sklearn/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .metrics.scorer import check_scoring
from .utils.fixes import bincount
from .gaussian_process.kernels import Kernel as GPKernel
from .exceptions import FitFailedWarning

__all__ = ['KFold',
'LabelKFold',
Expand Down Expand Up @@ -1428,10 +1429,6 @@ def cross_val_score(estimator, X, y=None, scoring=None, cv=None, n_jobs=1,
return np.array(scores)[:, 0]


class FitFailedWarning(RuntimeWarning):
pass


def _fit_and_score(estimator, X, y, scorer, train, test, verbose,
parameters, fit_params, return_train_score=False,
return_parameters=False, error_score='raise'):
Expand Down
2 changes: 1 addition & 1 deletion sklearn/decomposition/factor_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..utils import check_array, check_random_state
from ..utils.extmath import fast_logdet, fast_dot, randomized_svd, squared_norm
from ..utils.validation import check_is_fitted
from ..utils import ConvergenceWarning
from ..exceptions import ConvergenceWarning


class FactorAnalysis(BaseEstimator, TransformerMixin):
Expand Down
3 changes: 2 additions & 1 deletion sklearn/decomposition/kernel_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from scipy import linalg

from ..utils.arpack import eigsh
from ..utils.validation import check_is_fitted, NotFittedError
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError
from ..base import BaseEstimator, TransformerMixin
from ..preprocessing import KernelCenterer
from ..metrics.pairwise import pairwise_kernels
Expand Down
2 changes: 1 addition & 1 deletion sklearn/decomposition/nmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from ..utils.extmath import fast_dot
from ..utils.validation import check_is_fitted, check_non_negative
from ..utils import deprecated
from ..utils import ConvergenceWarning
from ..exceptions import ConvergenceWarning
from .cdnmf_fast import _update_cdnmf_fast


Expand Down
3 changes: 2 additions & 1 deletion sklearn/decomposition/online_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@
from ..base import BaseEstimator, TransformerMixin
from ..utils import (check_random_state, check_array,
gen_batches, gen_even_slices, _get_n_jobs)
from ..utils.validation import NotFittedError, check_non_negative
from ..utils.validation import check_non_negative
from ..utils.extmath import logsumexp
from ..externals.joblib import Parallel, delayed
from ..externals.six.moves import xrange
from ..exceptions import NotFittedError

from ._online_lda import (mean_change, _dirichlet_expectation_1d,
_dirichlet_expectation_2d)
Expand Down
2 changes: 1 addition & 1 deletion sklearn/decomposition/tests/test_factor_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_almost_equal
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils import ConvergenceWarning
from sklearn.exceptions import ConvergenceWarning
from sklearn.decomposition import FactorAnalysis


Expand Down
2 changes: 1 addition & 1 deletion sklearn/decomposition/tests/test_online_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sklearn.utils.testing import assert_raises_regexp
from sklearn.utils.testing import if_safe_multiprocessing_with_blas

from sklearn.utils.validation import NotFittedError
from sklearn.exceptions import NotFittedError
from sklearn.externals.six.moves import xrange


Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class calls the ``fit`` method of each sub-estimator on random samples
ExtraTreeClassifier, ExtraTreeRegressor)
from ..tree._tree import DTYPE, DOUBLE
from ..utils import check_random_state, check_array, compute_sample_weight
from ..utils.validation import DataConversionWarning, NotFittedError
from ..exceptions import DataConversionWarning, NotFittedError
from .base import BaseEnsemble, _partition_estimators
from ..utils.fixes import bincount
from ..utils.multiclass import check_classification_targets
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@
from ..utils.fixes import bincount
from ..utils.stats import _weighted_percentile
from ..utils.validation import check_is_fitted
from ..utils.validation import NotFittedError
from ..utils.multiclass import check_classification_targets
from ..exceptions import NotFittedError


class QuantileEstimator(BaseEstimator):
Expand Down
5 changes: 2 additions & 3 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,8 @@
from sklearn.utils.testing import assert_raises
from sklearn.utils.testing import assert_true
from sklearn.utils.testing import assert_warns
from sklearn.utils.testing import ignore_warnings
from sklearn.utils.validation import DataConversionWarning
from sklearn.utils.validation import NotFittedError
from sklearn.exceptions import DataConversionWarning
from sklearn.exceptions import NotFittedError

# toy sample
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
Expand Down
117 changes: 117 additions & 0 deletions sklearn/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
The :mod:`sklearn.exceptions` module includes all custom warnings and error
classes used across scikit-learn.
"""

__all__ = ['NotFittedError',
'ChangedBehaviorWarning',
'ConvergenceWarning',
'DataConversionWarning',
'DataDimensionalityWarning',
'EfficiencyWarning',
'FitFailedWarning',
'NonBLASDotWarning',
'UndefinedMetricWarning']


class NotFittedError(ValueError, AttributeError):
"""Exception class to raise if estimator is used before fitting.

This class inherits from both ValueError and AttributeError to help with
exception handling and backward compatibility.

Examples
--------
>>> from sklearn.svm import LinearSVC
>>> from sklearn.exceptions import NotFittedError
>>> try:
... LinearSVC().predict([[1, 2], [2, 3], [3, 4]])
... except NotFittedError as e:
... print(repr(e))
... # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
NotFittedError('This LinearSVC instance is not fitted yet',)
"""


class ChangedBehaviorWarning(UserWarning):
"""Warning class used to notify the user of any change in the behavior."""


class ConvergenceWarning(UserWarning):
"""Custom warning to capture convergence problems"""


class DataConversionWarning(UserWarning):
"""Warning used to notify implicit data conversions happening in the code.

This warning occurs when some input data needs to be converted or
interpreted in a way that may not match the user's expectations.

For example, this warning may occur when the the user
- passes an integer array to a function which expects float input and
will convert the input
- requests a non-copying operation, but a copy is required to meet the
implementation's data-type expectations;
- passes an input whose shape can be interpreted ambiguously.
"""


class DataDimensionalityWarning(UserWarning):
"""Custom warning to notify potential issues with data dimensionality.

For example, in random projection, this warning is raised when the
number of components, which quantifes the dimensionality of the target
projection space, is higher than the number of features, which quantifies
the dimensionality of the original source space, to imply that the
dimensionality of the problem will not be reduced.
"""


class EfficiencyWarning(UserWarning):
"""Warning used to notify the user of inefficient computation.

This warning notifies the user that the efficiency may not be optimal due
to some reason which may be included as a part of the warning message.
This may be subclassed into a more specific Warning class.
"""


class FitFailedWarning(RuntimeWarning):
"""Warning class used if there is an error while fitting the estimator.

This Warning is used in meta estimators GridSearchCV and RandomizedSearchCV
and the cross-validation helper function cross_val_score to warn when there
is an error while fitting the estimator.

Examples
--------
>>> from sklearn.grid_search import GridSearchCV
>>> from sklearn.svm import LinearSVC
>>> from sklearn.exceptions import FitFailedWarning
>>> import warnings
>>> warnings.simplefilter('always', FitFailedWarning)
>>> gs = GridSearchCV(LinearSVC(), {'C': [-1, -2]}, error_score=0)
>>> X, y = [[1, 2], [3, 4], [5, 6], [7, 8], [8, 9]], [0, 0, 0, 1, 1]
>>> with warnings.catch_warnings(record=True) as w:
... try:
... gs.fit(X, y) # This will raise a ValueError since C is < 0
... except ValueError:
... pass
... print(repr(w[-1].message))
... # doctest: +NORMALIZE_WHITESPACE
FitFailedWarning("Classifier fit failed. The score on this train-test
partition for these parameters will be set to 0.000000. Details:
\\nValueError('Penalty term must be positive; got (C=-2)',)",)
"""


class NonBLASDotWarning(EfficiencyWarning):
"""Warning used when the dot operation does not use BLAS.

This warning is used to notify the user that BLAS was not used for dot
operation and hence the efficiency may be affected.
"""


class UndefinedMetricWarning(UserWarning):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe this guy should have a description.

pass
6 changes: 3 additions & 3 deletions sklearn/feature_selection/from_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import numpy as np

from .base import SelectorMixin
from ..base import (TransformerMixin, BaseEstimator, clone,
MetaEstimatorMixin)
from ..base import TransformerMixin, BaseEstimator, clone
from ..externals import six

from ..utils import safe_mask, check_array, deprecated
from ..utils.validation import NotFittedError, check_is_fitted
from ..utils.validation import check_is_fitted
from ..exceptions import NotFittedError


def _get_feature_importances(estimator):
Expand Down
Loading