Skip to content

Commit 8329b3b

Browse files
committed
add support for scipy sparse matrices
1 parent 4c40c43 commit 8329b3b

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

sklearn/ensemble/tests/test_forest.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,6 @@ def test_non_OOB_unbiased_feature_importances(name, unbiased_importance_attribut
576576
assert not hasattr(clf, "oob_decision_function_")
577577

578578

579-
# TODO before merge: implement unbiased importance for sparse data
580579
@pytest.mark.parametrize("ForestClassifier", FOREST_CLASSIFIERS.values())
581580
@pytest.mark.parametrize("X_type", ["array", "sparse_csr", "sparse_csc"])
582581
@pytest.mark.parametrize(
@@ -607,8 +606,6 @@ def test_non_OOB_unbiased_feature_importances(name, unbiased_importance_attribut
607606
def test_forest_classifier_oob(
608607
ForestClassifier, X, y, X_type, lower_bound_accuracy, oob_score
609608
):
610-
if X_type != "array":
611-
pytest.skip()
612609
"""Check that OOB score is close to score on a test set."""
613610
X = _convert_container(X, constructor_name=X_type)
614611
X_train, X_test, y_train, y_test = train_test_split(
@@ -632,8 +629,6 @@ def test_forest_classifier_oob(
632629
test_score = oob_score(y_test, classifier.predict(X_test))
633630
else:
634631
test_score = classifier.score(X_test, y_test)
635-
print(test_score, classifier.oob_score_)
636-
637632
assert classifier.oob_score_ >= lower_bound_accuracy
638633

639634
abs_diff = abs(test_score - classifier.oob_score_)
@@ -673,8 +668,6 @@ def test_forest_classifier_oob(
673668
def test_forest_regressor_oob(ForestRegressor, X, y, X_type, lower_bound_r2, oob_score):
674669
"""Check that forest-based regressors provide an OOB score close to the
675670
score on a test set."""
676-
if X_type != "array":
677-
pytest.skip()
678671
X = _convert_container(X, constructor_name=X_type)
679672
X_train, X_test, y_train, y_test = train_test_split(
680673
X,
@@ -698,7 +691,6 @@ def test_forest_regressor_oob(ForestRegressor, X, y, X_type, lower_bound_r2, oob
698691
else:
699692
test_score = regressor.score(X_test, y_test)
700693
assert regressor.oob_score_ >= lower_bound_r2
701-
print(test_score, regressor.oob_score_)
702694

703695
assert abs(test_score - regressor.oob_score_) <= 0.1
704696

sklearn/tree/_tree.pyx

Lines changed: 56 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,20 +1286,50 @@ cdef class Tree:
12861286
float64_t[:, :, ::1] oob_node_values,
12871287
str method,
12881288
):
1289-
if issparse(X_test):
1290-
raise(NotImplementedError("does not support sparse X yet"))
1291-
if not isinstance(X_test, np.ndarray):
1292-
raise ValueError("X should be in np.ndarray format, got %s" % type(X_test))
1289+
cdef intp_t is_sparse = -1
1290+
cdef float32_t[:] X_data
1291+
cdef int32_t[:] X_indices
1292+
cdef int32_t[:] X_indptr
1293+
cdef int32_t[:] feature_to_sample
1294+
cdef float64_t[:] X_sample
1295+
cdef float64_t feature_value = 0.0
1296+
1297+
cdef float32_t[:, :] X_ndarray
1298+
12931299
if X_test.dtype != DTYPE:
12941300
raise ValueError("X.dtype should be np.float32, got %s" % X_test.dtype)
1295-
cdef const float32_t[:, :] X_ndarray = X_test
1301+
if issparse(X_test):
1302+
if X_test.format != "csr":
1303+
raise ValueError("X should be in csr_matrix format, got %s" % type(X_test))
1304+
is_sparse = 1
1305+
X_data = X_test.data
1306+
X_indices = X_test.indices
1307+
X_indptr = X_test.indptr
1308+
feature_to_sample = np.zeros(X_test.shape[1], dtype=np.int32)
1309+
X_sample = np.zeros(X_test.shape[1], dtype=np.float64)
1310+
1311+
# Unused
1312+
X_ndarray = np.zeros((0, 0), dtype=np.float32)
1313+
1314+
else:
1315+
if not isinstance(X_test, np.ndarray):
1316+
raise ValueError("X should be in np.ndarray format, got %s" % type(X_test))
1317+
is_sparse = 0
1318+
X_ndarray = X_test
1319+
1320+
# Unused
1321+
X_data = np.zeros(0, dtype=np.float32)
1322+
X_indices = np.zeros(0, dtype=np.int32)
1323+
X_indptr = np.zeros(0, dtype=np.int32)
1324+
feature_to_sample = np.zeros(0, dtype=np.int32)
1325+
X_sample = np.zeros(0, dtype=np.float64)
12961326

12971327
cdef intp_t n_samples = X_test.shape[0]
12981328
cdef intp_t* n_classes = self.n_classes
12991329
cdef intp_t node_count = self.node_count
13001330
cdef intp_t n_outputs = self.n_outputs
13011331
cdef intp_t max_n_classes = self.max_n_classes
1302-
cdef int k, c, node_idx, sample_idx = 0
1332+
cdef int k, c, node_idx, sample_idx, idx = 0
13031333
cdef float64_t[:, ::1] total_oob_weight = np.zeros((node_count, n_outputs), dtype=np.float64)
13041334
cdef int node_value_idx = -1
13051335

@@ -1310,6 +1340,11 @@ cdef class Tree:
13101340
with nogil:
13111341
# pass the oob samples in the tree and count them per node
13121342
for sample_idx in range(n_samples):
1343+
if is_sparse:
1344+
for idx in range(X_indptr[sample_idx], X_indptr[sample_idx + 1]):
1345+
# Store which features of sample_idx are non-zero, and their values
1346+
feature_to_sample[X_indices[idx]] = sample_idx
1347+
X_sample[X_indices[idx]] = X_data[idx]
13131348
# root node
13141349
node = self.nodes
13151350
node_idx = 0
@@ -1329,10 +1364,20 @@ cdef class Tree:
13291364

13301365
# child nodes
13311366
while node.left_child != _TREE_LEAF and node.right_child != _TREE_LEAF:
1332-
if X_ndarray[sample_idx, node.feature] <= node.threshold:
1333-
node_idx = node.left_child
1367+
if is_sparse:
1368+
if feature_to_sample[node.feature] == sample_idx:
1369+
feature_value = X_sample[node.feature]
1370+
else:
1371+
feature_value = 0.
1372+
if feature_value <= node.threshold:
1373+
node_idx = node.left_child
1374+
else:
1375+
node_idx = node.right_child
13341376
else:
1335-
node_idx = node.right_child
1377+
if X_ndarray[sample_idx, node.feature] <= node.threshold:
1378+
node_idx = node.left_child
1379+
else:
1380+
node_idx = node.right_child
13361381
if sample_weight[sample_idx] > 0.0:
13371382
has_oob_sample[node_idx] = 1
13381383
node = &self.nodes[node_idx]
@@ -1395,12 +1440,12 @@ cdef class Tree:
13951440
cdef float64_t[:, ::1] y_regression
13961441
if self.max_n_classes > 1:
13971442
# Classification
1398-
y_regression = np.zeros((1, 1), dtype=np.float64) # Unused
1443+
y_regression = np.zeros((0, 0), dtype=np.float64) # Unused
13991444
y_classification = np.ascontiguousarray(y_test, dtype=np.intp)
14001445
else:
14011446
# Regression
14021447
y_regression = np.ascontiguousarray(y_test, dtype=np.float64)
1403-
y_classification = np.zeros((1, 1), dtype=np.intp) # Unused
1448+
y_classification = np.zeros((0, 0), dtype=np.intp) # Unused
14041449

14051450
cdef float64_t[::1] sample_weight_view = np.ascontiguousarray(sample_weight, dtype=np.float64)
14061451
self._compute_oob_node_values_and_predictions(X_test, y_regression, y_classification, sample_weight_view, oob_pred, has_oob_sample, oob_node_values, method)

0 commit comments

Comments
 (0)