Skip to content

Commit a76b5be

Browse files
committed
don't to input validation in each tree for RandomForest.predict
1 parent 6d604c9 commit a76b5be

File tree

2 files changed

+31
-16
lines changed

2 files changed

+31
-16
lines changed

sklearn/ensemble/forest.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -365,7 +365,8 @@ def _set_oob_score(self, X, y):
365365
mask = np.ones(n_samples, dtype=np.bool)
366366
mask[estimator.indices_] = False
367367
mask_indices = sample_indices[mask]
368-
p_estimator = estimator.predict_proba(X[mask_indices, :])
368+
p_estimator = estimator.predict_proba(X[mask_indices, :],
369+
check_input=False)
369370

370371
if self.n_outputs_ == 1:
371372
p_estimator = [p_estimator]
@@ -508,7 +509,7 @@ class in a leaf.
508509
# Parallel loop
509510
all_proba = Parallel(n_jobs=n_jobs, verbose=self.verbose,
510511
backend="threading")(
511-
delayed(_parallel_helper)(e, 'predict_proba', X)
512+
delayed(_parallel_helper)(e, 'predict_proba', X, check_input=False)
512513
for e in self.estimators_)
513514

514515
# Reduce
@@ -614,6 +615,10 @@ def predict(self, X):
614615

615616
# Check data
616617
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
618+
if issparse(X) and (X.indices.dtype != np.intc or
619+
X.indptr.dtype != np.intc):
620+
raise ValueError("No support for np.int64 index based "
621+
"sparse matrices")
617622

618623
# Assign chunk of trees to jobs
619624
n_jobs, n_trees, starts = _partition_estimators(self.n_estimators,
@@ -622,7 +627,7 @@ def predict(self, X):
622627
# Parallel loop
623628
all_y_hat = Parallel(n_jobs=n_jobs, verbose=self.verbose,
624629
backend="threading")(
625-
delayed(_parallel_helper)(e, 'predict', X)
630+
delayed(_parallel_helper)(e, 'predict', X, check_input=False)
626631
for e in self.estimators_)
627632

628633
# Reduce
@@ -642,7 +647,7 @@ def _set_oob_score(self, X, y):
642647
mask = np.ones(n_samples, dtype=np.bool)
643648
mask[estimator.indices_] = False
644649
mask_indices = sample_indices[mask]
645-
p_estimator = estimator.predict(X[mask_indices, :])
650+
p_estimator = estimator.predict(X[mask_indices, :], check_input=False)
646651

647652
if self.n_outputs_ == 1:
648653
p_estimator = p_estimator[:, np.newaxis]

sklearn/tree/tree.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def fit(self, X, y, sample_weight=None, check_input=True):
309309

310310
return self
311311

312-
def predict(self, X):
312+
def predict(self, X, check_input=True):
313313
"""Predict class or regression value for X.
314314
315315
For a classification model, the predicted class for each sample in X is
@@ -323,16 +323,21 @@ def predict(self, X):
323323
``dtype=np.float32`` and if a sparse matrix is provided
324324
to a sparse ``csr_matrix``.
325325
326+
check_input : boolean, (default=True)
327+
Allow to bypass several input checking.
328+
Don't use this parameter unless you know what you do.
329+
326330
Returns
327331
-------
328332
y : array of shape = [n_samples] or [n_samples, n_outputs]
329333
The predicted classes, or the predict values.
330334
"""
331-
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
332-
if issparse(X) and (X.indices.dtype != np.intc or
333-
X.indptr.dtype != np.intc):
334-
raise ValueError("No support for np.int64 index based "
335-
"sparse matrices")
335+
if check_input:
336+
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
337+
if issparse(X) and (X.indices.dtype != np.intc or
338+
X.indptr.dtype != np.intc):
339+
raise ValueError("No support for np.int64 index based "
340+
"sparse matrices")
336341

337342
n_samples, n_features = X.shape
338343

@@ -541,12 +546,16 @@ def __init__(self,
541546
class_weight=class_weight,
542547
random_state=random_state)
543548

544-
def predict_proba(self, X):
549+
def predict_proba(self, X, check_input=True):
545550
"""Predict class probabilities of the input samples X.
546551
547552
The predicted class probability is the fraction of samples of the same
548553
class in a leaf.
549554
555+
check_input : boolean, (default=True)
556+
Allow to bypass several input checking.
557+
Don't use this parameter unless you know what you do.
558+
550559
Parameters
551560
----------
552561
X : array-like or sparse matrix of shape = [n_samples, n_features]
@@ -562,11 +571,12 @@ class in a leaf.
562571
classes corresponds to that in the attribute `classes_`.
563572
"""
564573
check_is_fitted(self, 'n_outputs_')
565-
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
566-
if issparse(X) and (X.indices.dtype != np.intc or
567-
X.indptr.dtype != np.intc):
568-
raise ValueError("No support for np.int64 index based "
569-
"sparse matrices")
574+
if check_input:
575+
X = check_array(X, dtype=DTYPE, accept_sparse="csr")
576+
if issparse(X) and (X.indices.dtype != np.intc or
577+
X.indptr.dtype != np.intc):
578+
raise ValueError("No support for np.int64 index based "
579+
"sparse matrices")
570580

571581
n_samples, n_features = X.shape
572582

0 commit comments

Comments
 (0)