Skip to content

Commit 347b109

Browse files
FIX fix scipy bug with sp.hstack in ClassifierChain and RegressorChain (#28524)
1 parent 4e82537 commit 347b109

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

sklearn/multioutput.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -669,6 +669,11 @@ def _get_predictions(self, X, *, output_method):
669669
hstack = sp.hstack if sp.issparse(X) else np.hstack
670670
for chain_idx, estimator in enumerate(self.estimators_):
671671
previous_predictions = Y_feature_chain[:, :chain_idx]
672+
# if `X` is a scipy sparse dok_array, we convert it to a sparse
673+
# coo_array format before hstacking, it's faster; see
674+
# https://github.com/scipy/scipy/issues/20060#issuecomment-1937007039:
675+
if sp.issparse(X) and not sp.isspmatrix(X) and X.format == "dok":
676+
X = sp.coo_array(X)
672677
X_aug = hstack((X, previous_predictions))
673678

674679
feature_predictions, _ = _get_response_values(

0 commit comments

Comments
 (0)