Skip to content

raise DeprecationWarnings and FutureWarnings as errors #11570

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 41 commits into from
Jul 17, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
c3af3e2
raise DeprecationWarnings and FutureWarnings as errors
amueller Jul 16, 2018
3cb2cf5
fix deprecation warnings in utils
amueller Jul 16, 2018
6d04e8b
fix preprocessing deprecation/future warnings
amueller Jul 16, 2018
6409f29
skip LSH tests because deprecated
amueller Jul 16, 2018
85f910f
deprecations in neighbors module
amueller Jul 16, 2018
3f4a6d2
pass max_iter to PAClassifier
amueller Jul 16, 2018
2fb0b97
lars positive deprecated
amueller Jul 16, 2018
477ca01
randomized lasso deprecated
amueller Jul 16, 2018
eb92a69
gridsearch cv iid in ridge tests
amueller Jul 16, 2018
e0bc687
more linear model deprecations
amueller Jul 16, 2018
506cf6d
pass through dtypes when filtering columns so we can transform featur…
amueller Jul 16, 2018
b12b9fb
catch scipy face deprecation warnings
amueller Jul 16, 2018
f0bd420
catch gridsearch deprecation warnings in text feature extraction
amueller Jul 16, 2018
5a106d2
remove some deprecation warnings from ensemble module
amueller Jul 16, 2018
4af1f20
decomposition deprecations caught
amueller Jul 16, 2018
d93115d
fix deprecation test in covariance
amueller Jul 16, 2018
679568c
filter iid warning
amueller Jul 16, 2018
2daad09
catch warnings in model selection
amueller Jul 16, 2018
b18d3cf
skip more deprecations in model selection
amueller Jul 16, 2018
248d9ad
ignore reorder warning, not sure if it's the right thing to do.
amueller Jul 17, 2018
7e8b259
more iid warnings
amueller Jul 17, 2018
96feb0f
bagging iid fix
amueller Jul 17, 2018
258eb99
fix using imputer on inf and NINF and none in bagging tests.
amueller Jul 17, 2018
de984dd
Merge branch 'simpleimputer_bagging' into test_warnings
amueller Jul 17, 2018
15bf683
add version to deprecation warning comment so we can git grep for it
amueller Jul 17, 2018
56dbf07
ignore iid warning instead of passing iid=False
amueller Jul 17, 2018
db01626
fix warning test
amueller Jul 17, 2018
5c1b4f3
fix calls to ignore_warnings
amueller Jul 17, 2018
0afd1e5
add missing pytest imports
amueller Jul 17, 2018
0718707
don't specify rcond for old numpy, fix deprecation test
amueller Jul 17, 2018
e7f4602
remove duplicate test
amueller Jul 17, 2018
68ea022
fix collections Mapping import
amueller Jul 17, 2018
3ac41dc
hotfixes for scipy atol and botched _fit_and_score
amueller Jul 17, 2018
e52175c
actually fix abc warning, also fix randomized lasso deprecation catch
amueller Jul 17, 2018
cf81437
Merge branch 'master' into test_warnings
amueller Jul 17, 2018
14ec4f3
catch vmeasure warnings
amueller Jul 17, 2018
6a25eb4
random forest n_estimator deprecations
amueller Jul 17, 2018
9e59ba2
add category to ignore_warning(FutureWarnings)
amueller Jul 17, 2018
df339f6
ignore matplotlib warnings
amueller Jul 17, 2018
642b6ac
catch nmi warnings in spectral embedding tests
amueller Jul 17, 2018
022907a
ami nmi deprecation warnings caught
amueller Jul 17, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ addopts =
--doctest-modules
--disable-pytest-warnings
-rs
filterwarnings =
error::DeprecationWarning
error::FutureWarning

[wheelhouse_uploader]
artifact_indexes=
Expand Down
7 changes: 7 additions & 0 deletions sklearn/cluster/tests/test_hierarchical.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
# License: BSD 3 clause
from tempfile import mkdtemp
import shutil
import pytest
from functools import partial

import numpy as np
Expand Down Expand Up @@ -142,6 +143,8 @@ def test_agglomerative_clustering_wrong_arg_memory():
assert_raises(ValueError, clustering.fit, X)


@pytest.mark.filterwarnings("ignore:the behavior of nmi will "
"change in version 0.22")
def test_agglomerative_clustering():
# Check that we obtain the correct number of clusters with
# agglomerative clustering.
Expand Down Expand Up @@ -250,6 +253,8 @@ def test_ward_agglomeration():
assert_raises(ValueError, agglo.fit, X[:0])


@pytest.mark.filterwarnings("ignore:the behavior of nmi will "
"change in version 0.22")
def test_single_linkage_clustering():
# Check that we get the correct result in two emblematic cases
moons, moon_labels = make_moons(noise=0.05, random_state=42)
Expand Down Expand Up @@ -311,6 +316,8 @@ def test_scikit_vs_scipy():
assert_raises(ValueError, _hc_cut, n_leaves + 1, children, n_leaves)


@pytest.mark.filterwarnings("ignore:the behavior of nmi will "
"change in version 0.22")
def test_identical_points():
# Ensure identical points are handled correctly when using mst with
# a sparse connectivity matrix
Expand Down
6 changes: 3 additions & 3 deletions sklearn/covariance/tests/test_graphical_lasso.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np
from scipy import linalg
import pytest

from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_array_less
Expand Down Expand Up @@ -151,6 +152,5 @@ def test_deprecated_grid_scores(random_state=1):
"0.19 and will be removed in 0.21. Use "
"``grid_scores_`` instead")

assert_warns_message(DeprecationWarning, depr_message,
lambda: graphical_lasso.grid_scores)
assert_equal(graphical_lasso.grid_scores, graphical_lasso.grid_scores_)
with pytest.warns(DeprecationWarning, match=depr_message):
assert_equal(graphical_lasso.grid_scores, graphical_lasso.grid_scores_)
3 changes: 3 additions & 0 deletions sklearn/decomposition/tests/test_dict_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_dict_learning_overcomplete():
assert_true(dico.components_.shape == (n_components, n_features))


@ignore_warnings(category=DeprecationWarning) # positive lars deprecated 0.22
@pytest.mark.parametrize("transform_algorithm", [
"lasso_lars",
"lasso_cd",
Expand Down Expand Up @@ -170,6 +171,7 @@ def test_dict_learning_online_shapes():
assert_equal(np.dot(code, dictionary).shape, X.shape)


@ignore_warnings(category=DeprecationWarning) # positive lars deprecated 0.22
@pytest.mark.parametrize("transform_algorithm", [
"lasso_lars",
"lasso_cd",
Expand Down Expand Up @@ -306,6 +308,7 @@ def test_sparse_encode_shapes():
assert_equal(code.shape, (n_samples, n_components))


@ignore_warnings(category=DeprecationWarning) # positive lars deprecated 0.22
@pytest.mark.parametrize("positive", [
False,
True,
Expand Down
5 changes: 4 additions & 1 deletion sklearn/decomposition/tests/test_kernel_pca.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import numpy as np
import scipy.sparse as sp
import pytest

from sklearn.utils.testing import (assert_array_almost_equal, assert_less,
assert_equal, assert_not_equal,
assert_raises)
assert_raises, ignore_warnings)

from sklearn.decomposition import PCA, KernelPCA
from sklearn.datasets import make_circles
Expand Down Expand Up @@ -172,6 +173,7 @@ def test_kernel_pca_invalid_kernel():
assert_raises(ValueError, kpca.fit, X_fit)


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_gridsearch_pipeline():
# Test if we can do a grid-search to find parameters to separate
# circles with a perceptron model.
Expand All @@ -186,6 +188,7 @@ def test_gridsearch_pipeline():
assert_equal(grid_search.best_score_, 1)


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_gridsearch_pipeline_precomputed():
# Test if we can do a grid-search to find parameters to separate
# circles with a perceptron model using a precomputed kernel.
Expand Down
16 changes: 9 additions & 7 deletions sklearn/ensemble/forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,16 +944,17 @@ class labels (multi-output problem).
>>> X, y = make_classification(n_samples=1000, n_features=4,
... n_informative=2, n_redundant=0,
... random_state=0, shuffle=False)
>>> clf = RandomForestClassifier(max_depth=2, random_state=0)
>>> clf = RandomForestClassifier(n_estimators=100, max_depth=2,
... random_state=0)
>>> clf.fit(X, y)
RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
max_depth=2, max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1,
oob_score=False, random_state=0, verbose=0, warm_start=False)
>>> print(clf.feature_importances_)
[0.17287856 0.80608704 0.01884792 0.00218648]
[0.14205973 0.76664038 0.0282433 0.06305659]
>>> print(clf.predict([[0, 0, 0, 0]]))
[1]

Expand Down Expand Up @@ -1188,18 +1189,19 @@ class RandomForestRegressor(ForestRegressor):
>>>
>>> X, y = make_regression(n_features=4, n_informative=2,
... random_state=0, shuffle=False)
>>> regr = RandomForestRegressor(max_depth=2, random_state=0)
>>> regr = RandomForestRegressor(max_depth=2, random_state=0,
... n_estimators=100)
>>> regr.fit(X, y)
RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=2,
max_features='auto', max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
min_weight_fraction_leaf=0.0, n_estimators=100, n_jobs=1,
oob_score=False, random_state=0, verbose=0, warm_start=False)
>>> print(regr.feature_importances_)
[0.17339552 0.81594114 0. 0.01066333]
[0.18146984 0.81473937 0.00145312 0.00233767]
>>> print(regr.predict([[0, 0, 0, 0]]))
[-2.50699856]
[-8.32987858]

Notes
-----
Expand Down
18 changes: 11 additions & 7 deletions sklearn/ensemble/tests/test_bagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# Author: Gilles Louppe
# License: BSD 3 clause

import pytest
import numpy as np

from sklearn.base import BaseEstimator
Expand Down Expand Up @@ -33,7 +34,7 @@
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_boston, load_iris, make_hastie_10_2
from sklearn.utils import check_random_state
from sklearn.preprocessing import Imputer
from sklearn.preprocessing import FunctionTransformer

from scipy.sparse import csc_matrix, csr_matrix

Expand Down Expand Up @@ -496,6 +497,7 @@ def test_parallel_regression():
assert_array_almost_equal(y1, y3)


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_gridsearch():
# Check that bagging ensembles can be grid-searched.
# Transform iris into a binary classification task
Expand Down Expand Up @@ -755,6 +757,12 @@ def test_set_oob_score_label_encoding():
assert_equal([x1, x2], [x3, x3])


def replace(X):
X = X.copy().astype('float')
X[~np.isfinite(X)] = 0
return X


def test_bagging_regressor_with_missing_inputs():
# Check that BaggingRegressor can accept X with missing/infinite data
X = np.array([
Expand All @@ -777,9 +785,7 @@ def test_bagging_regressor_with_missing_inputs():
for y in y_values:
regressor = DecisionTreeRegressor()
pipeline = make_pipeline(
Imputer(),
Imputer(missing_values=np.inf),
Imputer(missing_values=np.NINF),
FunctionTransformer(replace, validate=False),
regressor
)
pipeline.fit(X, y).predict(X)
Expand Down Expand Up @@ -807,9 +813,7 @@ def test_bagging_classifier_with_missing_inputs():
y = np.array([3, 6, 6, 6, 6])
classifier = DecisionTreeClassifier()
pipeline = make_pipeline(
Imputer(),
Imputer(missing_values=np.inf),
Imputer(missing_values=np.NINF),
FunctionTransformer(replace, validate=False),
classifier
)
pipeline.fit(X, y).predict(X)
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/tests/test_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,13 +439,13 @@ def check_oob_score_raise_error(name):
def test_oob_score_raise_error(name):
check_oob_score_raise_error(name)


def check_gridsearch(name):
forest = FOREST_CLASSIFIERS[name]()
clf = GridSearchCV(forest, {'n_estimators': (1, 2), 'max_depth': (1, 2)})
clf.fit(iris.data, iris.target)


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
@pytest.mark.parametrize('name', FOREST_CLASSIFIERS)
def test_gridsearch(name):
# Check that base trees can be grid-searched.
Expand Down
7 changes: 7 additions & 0 deletions sklearn/ensemble/tests/test_partial_dependence.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Testing for the partial dependence module.
"""
import pytest

import numpy as np
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -103,6 +104,8 @@ def test_partial_dependecy_input():
assert_raises(ValueError, partial_dependence, clf, [0], grid=grid)


@pytest.mark.filterwarnings('ignore: Using or importing the ABCs from')
# matplotlib Python3.7 warning
@if_matplotlib
def test_plot_partial_dependence():
# Test partial dependence plot function.
Expand Down Expand Up @@ -135,6 +138,8 @@ def test_plot_partial_dependence():
assert all(ax.has_data for ax in axs)


@pytest.mark.filterwarnings('ignore: Using or importing the ABCs from')
# matplotlib Python3.7 warning
@if_matplotlib
def test_plot_partial_dependence_input():
# Test partial dependence plot function input checks.
Expand Down Expand Up @@ -170,6 +175,8 @@ def test_plot_partial_dependence_input():
clf, X, [{'foo': 'bar'}])


@pytest.mark.filterwarnings('ignore: Using or importing the ABCs from')
# matplotlib Python3.7 warning
@if_matplotlib
def test_plot_partial_dependence_multiclass():
# Test partial dependence plot function on multi-class input.
Expand Down
6 changes: 4 additions & 2 deletions sklearn/ensemble/tests/test_voting_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,10 +371,12 @@ def test_set_estimator_none():
X1 = np.array([[1], [2]])
y1 = np.array([1, 2])
eclf1 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
voting='soft', weights=[0, 0.5]).fit(X1, y1)
voting='soft', weights=[0, 0.5],
flatten_transform=False).fit(X1, y1)

eclf2 = VotingClassifier(estimators=[('rf', clf2), ('nb', clf3)],
voting='soft', weights=[1, 0.5])
voting='soft', weights=[1, 0.5],
flatten_transform=False)
eclf2.set_params(rf=None).fit(X1, y1)
assert_array_almost_equal(eclf1.transform(X1),
np.array([[[0.7, 0.3], [0.3, 0.7]],
Expand Down
3 changes: 2 additions & 1 deletion sklearn/ensemble/tests/test_weight_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from sklearn.utils.testing import assert_array_equal, assert_array_less
from sklearn.utils.testing import assert_array_almost_equal
from sklearn.utils.testing import assert_array_almost_equal, ignore_warnings
from sklearn.utils.testing import assert_equal, assert_true, assert_greater
from sklearn.utils.testing import assert_raises, assert_raises_regexp

Expand Down Expand Up @@ -196,6 +196,7 @@ def test_staged_predict():
assert_array_almost_equal(score, staged_scores[-1])


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_gridsearch():
# Check that base trees can be grid-searched.
# AdaBoost classification
Expand Down
2 changes: 1 addition & 1 deletion sklearn/ensemble/voting_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class VotingClassifier(_BaseComposition, ClassifierMixin, TransformerMixin):
>>> from sklearn.naive_bayes import GaussianNB
>>> from sklearn.ensemble import RandomForestClassifier, VotingClassifier
>>> clf1 = LogisticRegression(random_state=1)
>>> clf2 = RandomForestClassifier(random_state=1)
>>> clf2 = RandomForestClassifier(n_estimators=50, random_state=1)
>>> clf3 = GaussianNB()
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> y = np.array([1, 1, 1, 2, 2, 2])
Expand Down
6 changes: 4 additions & 2 deletions sklearn/feature_extraction/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from sklearn.feature_extraction.image import (
img_to_graph, grid_to_graph, extract_patches_2d,
reconstruct_from_patches_2d, PatchExtractor, extract_patches)
from sklearn.utils.testing import assert_equal, assert_true, assert_raises
from sklearn.utils.testing import (assert_equal, assert_true, assert_raises,
ignore_warnings)


def test_img_to_graph():
Expand Down Expand Up @@ -55,6 +56,7 @@ def test_grid_to_graph():
assert_true(A.dtype == np.float64)


@ignore_warnings(category=DeprecationWarning) # scipy deprecation inside face
def test_connect_regions():
try:
face = sp.face(gray=True)
Expand All @@ -67,7 +69,7 @@ def test_connect_regions():
graph = img_to_graph(face, mask)
assert_equal(ndimage.label(mask)[1], connected_components(graph)[0])


@ignore_warnings(category=DeprecationWarning) # scipy deprecation inside face
def test_connect_regions_with_grid():
try:
face = sp.face(gray=True)
Expand Down
2 changes: 2 additions & 0 deletions sklearn/feature_extraction/tests/test_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,6 +730,7 @@ def test_vectorizer_inverse_transform(Vectorizer):
assert_array_equal(np.sort(terms), np.sort(terms2))


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_count_vectorizer_pipeline_grid_selection():
# raw documents
data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
Expand Down Expand Up @@ -766,6 +767,7 @@ def test_count_vectorizer_pipeline_grid_selection():
assert_equal(best_vectorizer.ngram_range, (1, 1))


@pytest.mark.filterwarnings('ignore: The default of the `iid`') # 0.22
def test_vectorizer_pipeline_grid_selection():
# raw documents
data = JUNK_FOOD_DOCS + NOTJUNK_FOOD_DOCS
Expand Down
4 changes: 3 additions & 1 deletion sklearn/feature_extraction/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import unicode_literals, division

import array
from collections import Mapping, defaultdict
from collections import defaultdict
import numbers
from operator import itemgetter
import re
Expand All @@ -32,6 +32,8 @@
from .stop_words import ENGLISH_STOP_WORDS
from ..utils.validation import check_is_fitted, check_array, FLOAT_DTYPES
from ..utils.fixes import sp_version
from ..utils.fixes import _Mapping as Mapping # noqa


__all__ = ['CountVectorizer',
'ENGLISH_STOP_WORDS',
Expand Down
4 changes: 2 additions & 2 deletions sklearn/feature_selection/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def transform(self, X):
X_r : array of shape [n_samples, n_selected_features]
The input samples with only the selected features.
"""
X = check_array(X, accept_sparse='csr')
X = check_array(X, dtype=None, accept_sparse='csr')
mask = self.get_support()
if not mask.any():
warn("No features were selected: either the data is"
Expand Down Expand Up @@ -111,7 +111,7 @@ def inverse_transform(self, X):
return Xt

support = self.get_support()
X = check_array(X)
X = check_array(X, dtype=None)
if support.sum() != X.shape[1]:
raise ValueError("X has a different shape than during fitting.")

Expand Down
Loading