Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
Original file line number Diff line number Diff line change
Expand Up @@ -1165,7 +1165,15 @@ def _compute_partial_dependence_recursion(self, grid, target_features):
return averaged_predictions

def _more_tags(self):
return {"allow_nan": True}
return {
"allow_nan": True,
"_xfail_checks": {
"check_estimators_pickle": (
"The memory views of the nodes parameter need to be defined"
"as read only in the Cython implementation."
),
},
}

@abstractmethod
def _get_loss(self, sample_weight):
Expand Down
10 changes: 10 additions & 0 deletions sklearn/preprocessing/_polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,3 +936,13 @@ def transform(self, X):
# We chose the last one.
indices = [j for j in range(XBS.shape[1]) if (j + 1) % n_splines != 0]
return XBS[:, indices]

def _more_tags(self):
return {
"_xfail_checks": {
"check_estimators_pickle": (
"Current Scipy implementation of _bsplines does not"
"support const memory views."
),
}
}
26 changes: 15 additions & 11 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ def _yield_checks(estimator):
# Test that estimators can be pickled, and once pickled
# give the same answer as before.
yield check_estimators_pickle
yield partial(check_estimators_pickle, readonly_memmap=True)

yield check_estimator_get_tags_default_keys

Expand Down Expand Up @@ -1870,7 +1871,7 @@ def check_nonsquare_error(name, estimator_orig):


@ignore_warnings
def check_estimators_pickle(name, estimator_orig):
def check_estimators_pickle(name, estimator_orig, readonly_memmap=False):
"""Test that we can pickle all estimators."""
check_methods = ["predict", "transform", "decision_function", "predict_proba"]

Expand Down Expand Up @@ -1899,16 +1900,19 @@ def check_estimators_pickle(name, estimator_orig):
set_random_state(estimator)
estimator.fit(X, y)

# pickle and unpickle!
pickled_estimator = pickle.dumps(estimator)
module_name = estimator.__module__
if module_name.startswith("sklearn.") and not (
"test_" in module_name or module_name.endswith("_testing")
):
# strict check for sklearn estimators that are not implemented in test
# modules.
assert b"version" in pickled_estimator
unpickled_estimator = pickle.loads(pickled_estimator)
if readonly_memmap:
unpickled_estimator = create_memmap_backed_data(estimator)
else:
# pickle and unpickle!
pickled_estimator = pickle.dumps(estimator)
module_name = estimator.__module__
if module_name.startswith("sklearn.") and not (
"test_" in module_name or module_name.endswith("_testing")
):
# strict check for sklearn estimators that are not implemented in test
# modules.
assert b"version" in pickled_estimator
unpickled_estimator = pickle.loads(pickled_estimator)

result = dict()
for method in check_methods:
Expand Down