
Run itrees in parallel during prediction. #14001


Closed
4 changes: 4 additions & 0 deletions doc/whats_new/v0.22.rst
@@ -47,6 +47,10 @@ Changelog
   validation data separately to avoid any data leak. :pr:`13933` by
   `NicolasHug`_.
 
+- |Fix| :class:`ensemble.IsolationForest` now runs parallel jobs
+  during :term:`predict`. :pr:`14001` by :user:`Sérgio Pereira
+  <sergiormpereira>`.
+
 :mod:`sklearn.linear_model`
 ...........................

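For context, a minimal sketch of what the changelog entry means in practice, assuming this patch is applied; the toy data and variable names are illustrative, not part of the PR:

import numpy as np
from sklearn.ensemble import IsolationForest

rng = np.random.RandomState(0)
X = rng.randn(1000, 4)

# Before this fix, n_jobs only parallelized fit; with it, predict and
# score_samples also split the trees across jobs.
iso = IsolationForest(n_estimators=100, n_jobs=2, random_state=0).fit(X)
labels = iso.predict(X)        # +1 for inliers, -1 for outliers
scores = iso.score_samples(X)  # anomaly scores, now computed in parallel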
61 changes: 42 additions & 19 deletions sklearn/ensemble/iforest.py
@@ -2,7 +2,10 @@
 # Alexandre Gramfort <alexandre.gramfort@telecom-paristech.fr>
 # License: BSD 3 clause
 
+from distutils.version import LooseVersion
+
+import joblib
+from joblib import Parallel, delayed
 import numbers
 import numpy as np
 from scipy.sparse import issparse
@@ -20,6 +23,7 @@
 from ..base import OutlierMixin
 
 from .bagging import BaseBagging
+from .base import _partition_estimators
 
 __all__ = ["IsolationForest"]

@@ -414,37 +418,56 @@ def _compute_chunked_score_samples(self, X):
         return scores
 
     def _compute_score_samples(self, X, subsample_features):
-        """Compute the score of each samples in X going through the extra trees.
+        """Compute the score of each sample in X going through the extra
+        trees.
 
         Parameters
         ----------
         X : array-like or sparse matrix
 
         subsample_features : bool,
             whether features should be subsampled
 
         Returns
         -------
         ndarray
             The anomaly scores for each sample
         """
-        n_samples = X.shape[0]
+        def get_depths(_X, trees, trees_features, _subsample_features):
+            n = _X.shape[0]
+            batch_depths = np.zeros(n, order="f")
+
+            for tree, features in zip(trees, trees_features):
+                X_subset = _X[:, features] if _subsample_features else _X
+
+                leaves_index = tree.apply(X_subset)
+                node_indicator = tree.decision_path(X_subset)
+                n_samples_leaf = tree.tree_.n_node_samples[leaves_index]
+
+                batch_depths += np.ravel(node_indicator.sum(axis=1)) \
+                    + _average_path_length(n_samples_leaf) - 1.0
+
+            return batch_depths
+
+        n_jobs, n_estimators, starts = _partition_estimators(
+            self.n_estimators, self.n_jobs)
[Inline review comment from a scikit-learn member, attached to the _partition_estimators call above:]
Looking at this issue again, I think we can do this:

Suggested change:
-            self.n_estimators, self.n_jobs)
+            self.n_estimators, None)

which allows joblib.parallel_backend to control n_jobs. At a higher level, we can use parallel_backend to control n_jobs:

with parallel_backend("loky", n_jobs=6):
    iso.score_samples(X)

Details: _partition_estimators uses effective_n_jobs:

n_jobs = min(effective_n_jobs(n_jobs), n_estimators)

and effective_n_jobs queries the backend configuration as follows:

with parallel_backend("loky", n_jobs=4):
    print(effective_n_jobs(None))
# 4

# default is 1
print(effective_n_jobs(None))
# 1

-        depths = np.zeros(n_samples, order="f")
+        old_joblib = LooseVersion(joblib.__version__) < LooseVersion('0.12')
+        check_pickle = False if old_joblib else None
 
-        for tree, features in zip(self.estimators_, self.estimators_features_):
-            X_subset = X[:, features] if subsample_features else X
+        par_exec = Parallel(n_jobs=n_jobs, **self._parallel_args())
+        par_results = par_exec(
+            delayed(get_depths, check_pickle=check_pickle)(
+                _X=X, trees=self.estimators_[starts[i]: starts[i + 1]],
+                trees_features=self.estimators_features_[
+                    starts[i]: starts[i + 1]],
+                _subsample_features=subsample_features)
+            for i in range(n_jobs))
 
-            leaves_index = tree.apply(X_subset)
-            node_indicator = tree.decision_path(X_subset)
-            n_samples_leaf = tree.tree_.n_node_samples[leaves_index]
+        depths = np.sum(par_results, axis=0)
 
-            depths += (
-                np.ravel(node_indicator.sum(axis=1))
-                + _average_path_length(n_samples_leaf)
-                - 1.0
-            )
-
-        scores = 2 ** (
-            -depths
-            / (len(self.estimators_)
-               * _average_path_length([self.max_samples_]))
-        )
+        scores = 2 ** (-depths / (len(self.estimators_)
+                                  * _average_path_length([self.max_samples_])))
         return scores
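For intuition about how the work is split, here is a hedged sketch of the chunking that _partition_estimators performs (simplified: it assumes n_jobs is a positive int and ignores effective_n_jobs); partition_estimators below is a stand-in, not the sklearn helper itself:

import numpy as np

def partition_estimators(n_estimators, n_jobs):
    # Never use more jobs than there are trees
    n_jobs = min(n_jobs, n_estimators)
    # Spread the trees as evenly as possible across the jobs
    per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
    per_job[:n_estimators % n_jobs] += 1
    starts = np.concatenate(([0], np.cumsum(per_job)))
    return n_jobs, per_job.tolist(), starts.tolist()

n_jobs, per_job, starts = partition_estimators(100, 3)
print(per_job)  # [34, 33, 33]
print(starts)   # [0, 34, 67, 100]; job i handles estimators_[starts[i]:starts[i + 1]]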


5 changes: 5 additions & 0 deletions sklearn/ensemble/tests/test_iforest.py
@@ -169,15 +169,20 @@ def test_iforest_parallel_regression():
 
     ensemble.set_params(n_jobs=1)
     y1 = ensemble.predict(X_test)
+    scores1 = ensemble.score_samples(X_test)
     ensemble.set_params(n_jobs=2)
     y2 = ensemble.predict(X_test)
+    scores2 = ensemble.score_samples(X_test)
     assert_array_almost_equal(y1, y2)
+    assert_array_almost_equal(scores1, scores2)
 
     ensemble = IsolationForest(n_jobs=1,
                                random_state=0).fit(X_train)
 
     y3 = ensemble.predict(X_test)
+    scores3 = ensemble.score_samples(X_test)
    assert_array_almost_equal(y1, y3)
+    assert_array_almost_equal(scores1, scores3)
 
 
 def test_iforest_performance():