Skip to content

Commit c59b2f6

Browse files
amuellerlarsmans
authored andcommitted
ENH get rid of imports in test_common by checking by names, not classes.
1 parent 9d49b6d commit c59b2f6

File tree

1 file changed

+16
-23
lines changed

1 file changed

+16
-23
lines changed

sklearn/tests/test_common.py

+16-23
Original file line numberDiff line numberDiff line change
@@ -38,30 +38,23 @@
3838
from sklearn.svm.base import BaseLibSVM
3939

4040
# import "special" estimators
41-
from sklearn.decomposition import SparseCoder
4241
from sklearn.pls import _PLS, PLSCanonical, PLSRegression, CCA, PLSSVD
43-
from sklearn.ensemble import RandomTreesEmbedding
4442
from sklearn.feature_selection import SelectKBest
45-
from sklearn.dummy import DummyClassifier, DummyRegressor
4643
from sklearn.naive_bayes import MultinomialNB, BernoulliNB
47-
from sklearn.covariance import EllipticEnvelope, EllipticEnvelop
48-
from sklearn.feature_extraction import DictVectorizer, FeatureHasher
49-
from sklearn.feature_extraction.text import TfidfTransformer
5044
from sklearn.kernel_approximation import AdditiveChi2Sampler
51-
from sklearn.preprocessing import (LabelBinarizer, LabelEncoder, Binarizer,
52-
Normalizer, OneHotEncoder)
45+
from sklearn.preprocessing import Binarizer, Normalizer
5346
from sklearn.cluster import (WardAgglomeration, AffinityPropagation,
5447
SpectralClustering)
55-
from sklearn.isotonic import IsotonicRegression
5648
from sklearn.random_projection import (GaussianRandomProjection,
5749
SparseRandomProjection)
5850

5951
from sklearn.cross_validation import train_test_split
6052

61-
dont_test = [SparseCoder, EllipticEnvelope, EllipticEnvelop, DictVectorizer,
62-
LabelBinarizer, LabelEncoder, TfidfTransformer,
63-
IsotonicRegression, OneHotEncoder, RandomTreesEmbedding,
64-
FeatureHasher, DummyClassifier, DummyRegressor]
53+
dont_test = ['SparseCoder', 'EllipticEnvelope', 'EllipticEnvelop',
54+
'DictVectorizer', 'LabelBinarizer', 'LabelEncoder',
55+
'TfidfTransformer', 'IsotonicRegression', 'OneHotEncoder',
56+
'RandomTreesEmbedding', 'FeatureHasher', 'DummyClassifier',
57+
'DummyRegressor']
6558

6659

6760
def test_all_estimators():
@@ -72,7 +65,7 @@ def test_all_estimators():
7265

7366
for name, E in estimators:
7467
# some can just not be sensibly default constructed
75-
if E in dont_test:
68+
if name in dont_test:
7669
continue
7770
# test default-constructibility
7871
# get rid of deprecation warnings
@@ -136,7 +129,7 @@ def test_estimators_sparse_data():
136129
estimators = [(name, E) for name, E in estimators
137130
if issubclass(E, (ClassifierMixin, RegressorMixin))]
138131
for name, Clf in estimators:
139-
if Clf in dont_test:
132+
if name in dont_test:
140133
continue
141134
# catch deprecation warnings
142135
with warnings.catch_warnings(record=True):
@@ -172,7 +165,7 @@ def test_transformers():
172165
for name, Trans in transformers:
173166
trans = None
174167

175-
if Trans in dont_test:
168+
if name in dont_test:
176169
continue
177170
# these don't actually fit the data:
178171
if Trans in [AdditiveChi2Sampler, Binarizer, Normalizer]:
@@ -250,7 +243,7 @@ def test_transformers_sparse_data():
250243
y = (4 * rng.rand(40)).astype(np.int)
251244
estimators = all_estimators(type_filter='transformer')
252245
for name, Trans in estimators:
253-
if Trans in dont_test:
246+
if name in dont_test:
254247
continue
255248
# catch deprecation warnings
256249
with warnings.catch_warnings(record=True):
@@ -304,7 +297,7 @@ def test_estimators_nan_inf():
304297
" transform.")
305298
for X_train in [X_train_nan, X_train_inf]:
306299
for name, Est in estimators:
307-
if Est in dont_test:
300+
if name in dont_test:
308301
continue
309302
if Est in (_PLS, PLSCanonical, PLSRegression, CCA, PLSSVD):
310303
continue
@@ -390,7 +383,7 @@ def test_classifiers_one_label():
390383
error_string_predict = ("Classifier can't predict when only one class is "
391384
"present.")
392385
for name, Clf in classifiers:
393-
if Clf in dont_test:
386+
if name in dont_test:
394387
continue
395388
# catch deprecation warnings
396389
with warnings.catch_warnings(record=True):
@@ -471,7 +464,7 @@ def test_classifiers_train():
471464
n_classes = len(classes)
472465
n_samples, n_features = X.shape
473466
for name, Clf in classifiers:
474-
if Clf in dont_test:
467+
if name in dont_test:
475468
continue
476469
if Clf in [MultinomialNB, BernoulliNB]:
477470
# TODO also test these!
@@ -539,7 +532,7 @@ def test_classifiers_classes():
539532
# TODO: make work with next line :)
540533
#y = y.astype(np.str)
541534
for name, Clf in classifiers:
542-
if Clf in dont_test:
535+
if name in dont_test:
543536
continue
544537
if Clf in [MultinomialNB, BernoulliNB]:
545538
# TODO also test these!
@@ -570,7 +563,7 @@ def test_regressors_int():
570563
X = StandardScaler().fit_transform(X)
571564
y = np.random.randint(2, size=X.shape[0])
572565
for name, Reg in regressors:
573-
if Reg in dont_test or Reg in (CCA,):
566+
if name in dont_test or Reg in (CCA,):
574567
continue
575568
# catch deprecation warnings
576569
with warnings.catch_warnings(record=True):
@@ -605,7 +598,7 @@ def test_regressors_train():
605598
y = StandardScaler().fit_transform(y)
606599
succeeded = True
607600
for name, Reg in regressors:
608-
if Reg in dont_test:
601+
if name in dont_test:
609602
continue
610603
# catch deprecation warnings
611604
with warnings.catch_warnings(record=True):

0 commit comments

Comments
 (0)