@@ -1286,20 +1286,50 @@ cdef class Tree:
1286
1286
float64_t[:, :, ::1 ] oob_node_values,
1287
1287
str method,
1288
1288
):
1289
- if issparse(X_test):
1290
- raise (NotImplementedError (" does not support sparse X yet" ))
1291
- if not isinstance (X_test, np.ndarray):
1292
- raise ValueError (" X should be in np.ndarray format, got %s " % type (X_test))
1289
+ cdef intp_t is_sparse = - 1
1290
+ cdef float32_t[:] X_data
1291
+ cdef int32_t[:] X_indices
1292
+ cdef int32_t[:] X_indptr
1293
+ cdef int32_t[:] feature_to_sample
1294
+ cdef float64_t[:] X_sample
1295
+ cdef float64_t feature_value = 0.0
1296
+
1297
+ cdef float32_t[:, :] X_ndarray
1298
+
1293
1299
if X_test.dtype != DTYPE:
1294
1300
raise ValueError (" X.dtype should be np.float32, got %s " % X_test.dtype)
1295
- cdef const float32_t[:, :] X_ndarray = X_test
1301
+ if issparse(X_test):
1302
+ if X_test.format != " csr" :
1303
+ raise ValueError (" X should be in csr_matrix format, got %s " % type (X_test))
1304
+ is_sparse = 1
1305
+ X_data = X_test.data
1306
+ X_indices = X_test.indices
1307
+ X_indptr = X_test.indptr
1308
+ feature_to_sample = np.zeros(X_test.shape[1 ], dtype = np.int32)
1309
+ X_sample = np.zeros(X_test.shape[1 ], dtype = np.float64)
1310
+
1311
+ # Unused
1312
+ X_ndarray = np.zeros((0 , 0 ), dtype = np.float32)
1313
+
1314
+ else :
1315
+ if not isinstance (X_test, np.ndarray):
1316
+ raise ValueError (" X should be in np.ndarray format, got %s " % type (X_test))
1317
+ is_sparse = 0
1318
+ X_ndarray = X_test
1319
+
1320
+ # Unused
1321
+ X_data = np.zeros(0 , dtype = np.float32)
1322
+ X_indices = np.zeros(0 , dtype = np.int32)
1323
+ X_indptr = np.zeros(0 , dtype = np.int32)
1324
+ feature_to_sample = np.zeros(0 , dtype = np.int32)
1325
+ X_sample = np.zeros(0 , dtype = np.float64)
1296
1326
1297
1327
cdef intp_t n_samples = X_test.shape[0 ]
1298
1328
cdef intp_t* n_classes = self .n_classes
1299
1329
cdef intp_t node_count = self .node_count
1300
1330
cdef intp_t n_outputs = self .n_outputs
1301
1331
cdef intp_t max_n_classes = self .max_n_classes
1302
- cdef int k, c, node_idx, sample_idx = 0
1332
+ cdef int k, c, node_idx, sample_idx, idx = 0
1303
1333
cdef float64_t[:, ::1 ] total_oob_weight = np.zeros((node_count, n_outputs), dtype = np.float64)
1304
1334
cdef int node_value_idx = - 1
1305
1335
@@ -1310,6 +1340,11 @@ cdef class Tree:
1310
1340
with nogil:
1311
1341
# pass the oob samples in the tree and count them per node
1312
1342
for sample_idx in range (n_samples):
1343
+ if is_sparse:
1344
+ for idx in range (X_indptr[sample_idx], X_indptr[sample_idx + 1 ]):
1345
+ # Store wich feature of sample_idx is non zero and its value
1346
+ feature_to_sample[X_indices[idx]] = sample_idx
1347
+ X_sample[X_indices[idx]] = X_data[idx]
1313
1348
# root node
1314
1349
node = self .nodes
1315
1350
node_idx = 0
@@ -1329,10 +1364,20 @@ cdef class Tree:
1329
1364
1330
1365
# child nodes
1331
1366
while node.left_child != _TREE_LEAF and node.right_child != _TREE_LEAF:
1332
- if X_ndarray[sample_idx, node.feature] <= node.threshold:
1333
- node_idx = node.left_child
1367
+ if is_sparse:
1368
+ if feature_to_sample[node.feature] == sample_idx:
1369
+ feature_value = X_sample[node.feature]
1370
+ else :
1371
+ feature_value = 0.
1372
+ if feature_value <= node.threshold:
1373
+ node_idx = node.left_child
1374
+ else :
1375
+ node_idx = node.right_child
1334
1376
else :
1335
- node_idx = node.right_child
1377
+ if X_ndarray[sample_idx, node.feature] <= node.threshold:
1378
+ node_idx = node.left_child
1379
+ else :
1380
+ node_idx = node.right_child
1336
1381
if sample_weight[sample_idx] > 0.0 :
1337
1382
has_oob_sample[node_idx] = 1
1338
1383
node = & self .nodes[node_idx]
@@ -1395,12 +1440,12 @@ cdef class Tree:
1395
1440
cdef float64_t[:, ::1 ] y_regression
1396
1441
if self .max_n_classes > 1 :
1397
1442
# Classification
1398
- y_regression = np.zeros((1 , 1 ), dtype = np.float64) # Unused
1443
+ y_regression = np.zeros((0 , 0 ), dtype = np.float64) # Unused
1399
1444
y_classification = np.ascontiguousarray(y_test, dtype = np.intp)
1400
1445
else :
1401
1446
# Regression
1402
1447
y_regression = np.ascontiguousarray(y_test, dtype = np.float64)
1403
- y_classification = np.zeros((1 , 1 ), dtype = np.intp) # Unused
1448
+ y_classification = np.zeros((0 , 0 ), dtype = np.intp) # Unused
1404
1449
1405
1450
cdef float64_t[::1 ] sample_weight_view = np.ascontiguousarray(sample_weight, dtype = np.float64)
1406
1451
self ._compute_oob_node_values_and_predictions(X_test, y_regression, y_classification, sample_weight_view, oob_pred, has_oob_sample, oob_node_values, method)
0 commit comments