-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+1] Uncontroversial fixes from estimator tags branch #8086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
7ef1deb
c99b9ec
534b0c5
c7cd00d
91559ce
6ee218d
27775e9
f4c9d60
0555e22
30bdd04
5ed1174
adee7a3
a83697f
9ce4747
f727e89
8bbb742
746ccdb
e2d0464
9564e0f
8f4ec6a
c712ca3
bb7f085
c18d646
8a3ea13
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
import numpy as np | ||
import itertools | ||
|
||
from sklearn.exceptions import ConvergenceWarning | ||
|
||
|
@@ -25,10 +26,27 @@ | |
X = rng_global.randn(n_samples, n_features) | ||
|
||
|
||
def test_sparse_encode_shapes_omp(): | ||
rng = np.random.RandomState(0) | ||
algorithms = ['omp', 'lasso_lars', 'lasso_cd', 'lars', 'threshold'] | ||
for n_components, n_samples in itertools.product([1, 5], [1, 9]): | ||
X_ = rng.randn(n_samples, n_features) | ||
dictionary = rng.randn(n_components, n_features) | ||
for algorithm, n_jobs in itertools.product(algorithms, [1, 3]): | ||
code = sparse_encode(X_, dictionary, algorithm=algorithm, | ||
n_jobs=n_jobs) | ||
assert_equal(code.shape, (n_samples, n_components)) | ||
|
||
|
||
def test_dict_learning_shapes(): | ||
n_components = 5 | ||
dico = DictionaryLearning(n_components, random_state=0).fit(X) | ||
assert_true(dico.components_.shape == (n_components, n_features)) | ||
assert_equal(dico.components_.shape, (n_components, n_features)) | ||
|
||
n_components = 1 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @GaelVaroquaux this is the regression test for the SparseEncode change. I can add a more direct test on SparseEncode, though. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. No, I think that this is good. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added more tests now, one testing the core issue in sparse_encode. |
||
dico = DictionaryLearning(n_components, random_state=0).fit(X) | ||
assert_equal(dico.components_.shape, (n_components, n_features)) | ||
assert_equal(dico.transform(X).shape, (X.shape[0], n_components)) | ||
|
||
|
||
def test_dict_learning_overcomplete(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -64,7 +64,7 @@ | |
from ..exceptions import NotFittedError | ||
|
||
|
||
class QuantileEstimator(BaseEstimator): | ||
class QuantileEstimator(object): |
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why this change? (There is probably a good reason — just asking.) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These are not scikit-learn estimators; they don't fulfill the sklearn estimator API, and the inheritance doesn't provide any functionality. So I think having them inherit is rather confusing. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. They are also not in a position for get_params/set_params to be used. |
||
"""An estimator predicting the alpha-quantile of the training targets.""" | ||
def __init__(self, alpha=0.9): | ||
if not 0 < alpha < 1.0: | ||
|
@@ -86,7 +86,7 @@ def predict(self, X): | |
return y | ||
|
||
|
||
class MeanEstimator(BaseEstimator): | ||
class MeanEstimator(object): | ||
"""An estimator predicting the mean of the training targets.""" | ||
def fit(self, X, y, sample_weight=None): | ||
if sample_weight is None: | ||
|
@@ -102,7 +102,7 @@ def predict(self, X): | |
return y | ||
|
||
|
||
class LogOddsEstimator(BaseEstimator): | ||
class LogOddsEstimator(object): | ||
"""An estimator predicting the log odds ratio.""" | ||
scale = 1.0 | ||
|
||
|
@@ -132,7 +132,7 @@ class ScaledLogOddsEstimator(LogOddsEstimator): | |
scale = 0.5 | ||
|
||
|
||
class PriorProbabilityEstimator(BaseEstimator): | ||
class PriorProbabilityEstimator(object): | ||
"""An estimator predicting the probability of each | ||
class in the training data. | ||
""" | ||
|
@@ -150,7 +150,7 @@ def predict(self, X): | |
return y | ||
|
||
|
||
class ZeroEstimator(BaseEstimator): | ||
class ZeroEstimator(object): | ||
"""An estimator that simply predicts zero. """ | ||
|
||
def fit(self, X, y, sample_weight=None): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could use:
check_consistent_length(X.T, dictionary.T)
You could argue that's less clear, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It'll say "different number of samples" in the error, right?