Skip to content

[WIP] Common test for equivalence between sparse and dense matrices. #7590

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 56 additions & 1 deletion sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _yield_non_meta_checks(name, Estimator):
yield check_sparsify_coefficients

yield check_estimator_sparse_data

yield check_estimator_sparse_dense
# Test that estimators can be pickled, and once pickled
# give the same answer as before.
yield check_estimators_pickle
Expand Down Expand Up @@ -312,6 +312,11 @@ def set_testing_parameters(estimator):
if not isinstance(estimator, ProjectedGradientNMF):
estimator.set_params(solver='cd')

if "KNeighbors" in estimator.__class__.__name__ :
# Override the default 'auto' for sparse dense equivalence
# since only 'brute' algo is used for sparse see #1572
estimator.set_params(algorithm='brute')


class NotAnArray(object):
" An object that is convertable to an array"
Expand Down Expand Up @@ -344,6 +349,7 @@ def check_estimator_sparse_data(name, Estimator):
estimator = Estimator()
set_testing_parameters(estimator)
# fit and predict

try:
with ignore_warnings(category=DeprecationWarning):
estimator.fit(X, y)
Expand Down Expand Up @@ -1554,3 +1560,52 @@ def check_classifiers_regression_target(name, Estimator):
e = Estimator()
msg = 'Unknown label type: '
assert_raises_regex(ValueError, msg, e.fit, X, y)


def check_estimator_sparse_dense(name, Estimator):
    """Check that sparse and dense input lead to the same predictions.

    For every scipy sparse format in turn, two instances of ``Estimator``
    are built through the same constructor call, ``set_testing_parameters``
    and ``set_random_state``. One is fitted on a dense array, the other on
    the same data stored in the sparse format, and their ``predict``
    outputs must agree to 2 decimals and have the same shape.

    Parameters
    ----------
    name : str
        Name of the estimator, used in the diagnostic messages and to
        special-case 'Scaler' / 'StandardScaler' construction.
    Estimator : class
        Estimator class; it is instantiated by this check.

    Raises
    ------
    AssertionError
        If dense and sparse predictions differ, or ``predict_proba``
        returns an unexpected shape.
    Exception
        Any other error raised while fitting or predicting is re-raised
        after a diagnostic message is printed. The only exception is a
        ``TypeError`` whose repr mentions 'sparse': it is swallowed and
        the check moves on to the next sparse format.
    """
    n_samples, n_classes = 40, 4
    rng = np.random.RandomState(0)
    X = rng.rand(n_samples, 10)
    # Zero every entry below 0.8 so that the sparse copies hold few values.
    X[X < .8] = 0
    X_csr = sparse.csr_matrix(X)
    # Builtin ``int`` rather than the ``np.int`` alias, which is deprecated
    # and absent from newer NumPy releases; the resulting dtype is the same.
    y = (n_classes * rng.rand(n_samples)).astype(int)

    for sparse_format in ['csr', 'csc', 'dok', 'lil', 'coo', 'dia', 'bsr']:
        X_sp = X_csr.asformat(sparse_format)

        # NOTE(review): with_mean=False is presumably needed because these
        # scalers cannot center sparse input -- confirm.
        if name in ['Scaler', 'StandardScaler']:
            init_params = {'with_mean': False}
        else:
            init_params = {}

        # catch deprecation warnings
        with ignore_warnings(category=DeprecationWarning):
            estimator = Estimator(**init_params)
            estimator_sp = Estimator(**init_params)

        for est in (estimator, estimator_sp):
            set_testing_parameters(est)
        for est in (estimator, estimator_sp):
            set_random_state(est)

        # fit and predict
        try:
            with ignore_warnings(category=DeprecationWarning):
                estimator_sp.fit(X_sp, y)
                estimator.fit(X, y)
            if hasattr(estimator, "predict"):
                pred = estimator.predict(X)
                pred_sp = estimator_sp.predict(X_sp)
                assert_array_almost_equal(pred, pred_sp, 2)
                assert_equal(pred.shape, pred_sp.shape)
            if hasattr(estimator, 'predict_proba'):
                probs = estimator.predict_proba(X)
                assert_equal(probs.shape, (X.shape[0], n_classes))
        except AssertionError:
            # A failed comparison is a genuine sparse/dense mismatch. Let it
            # propagate untouched instead of reporting it below as a failure
            # to handle sparse data gracefully.
            raise
        except TypeError as e:
            # A TypeError that mentions 'sparse' is taken as the estimator
            # explicitly declining sparse input; that is acceptable.
            if 'sparse' not in repr(e):
                print("Estimator %s doesn't seem to fail gracefully on "
                      "sparse data: error message state explicitly that "
                      "sparse input is not supported if this is not the case."
                      % name)
                raise
        except Exception:
            print("Estimator %s doesn't seem to fail gracefully on "
                  "sparse data: it should raise a TypeError if sparse input "
                  "is explicitly not supported." % name)
            raise