ENH Optimize runtime for IsolationForest #23252
Thank you for the follow-up PR!
Since in IsolationForest we're randomly picking a single feature to split, shall we deprecate the max_features argument and always set it to 1.0 for better runtime?
I feel like there is still a use case for having each tree see a different subset of features. What do you think @albertcthomas?
sklearn/ensemble/_bagging.py (Outdated)

```
@@ -135,8 +135,13 @@ def _parallel_build_estimators(
    not_indices_mask = ~indices_to_mask(indices, n_samples)
    curr_sample_weight[not_indices_mask] = 0

    estimator_fit(X[:, features], y, sample_weight=curr_sample_weight)
    # Indicator of if indexing is necessary
    require_indexing = bootstrap_features or max_features != n_features
```
I think we can define this outside the loop. This optimization is also useful for the else: case a few lines below. Something like:

```python
# outside of loop
requires_feature_indexing = bootstrap_features or max_features != n_features

# in loop
...
X_ = X[:, features] if requires_feature_indexing else X
estimator_fit(X_, y, sample_weight=curr_sample_weight)
else:
    X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
    estimator_fit(X_, y[indices])
```
Updated as suggested :)
Can you update the opening comment to include the benchmark script so it is easier for other reviewers to run the benchmark? (without needing to find it in the other PR)
Sure! I've updated the benchmark result in the opening comment.
I get a segfault when I execute #19275 (comment) with this branch. Here is the top of the gdb backtrace:
I tried different settings; it seems like the segmentation fault happens when using sparse data and n_jobs > 1.
There are several places in the code where the data is mutated in place. This is not a problem when working with […]. We need to sort out how to make this thread-safe without introducing costly data-copies.
Actually I am not sure. The tree models should not do in-place modifications of sparse input data.
I confirm that making an explicit copy of X fixes the segfault:

```diff
diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py
index 9c8faa783..7798dc880 100644
--- a/sklearn/ensemble/_bagging.py
+++ b/sklearn/ensemble/_bagging.py
@@ -137,10 +137,10 @@ def _parallel_build_estimators(
             not_indices_mask = ~indices_to_mask(indices, n_samples)
             curr_sample_weight[not_indices_mask] = 0

-            X_ = X[:, features] if requires_feature_indexing else X
+            X_ = X[:, features] if requires_feature_indexing else X.copy()
             estimator_fit(X_, y, sample_weight=curr_sample_weight)
         else:
-            X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
+            X_ = X[indices][:, features] if requires_feature_indexing else X[indices].copy()
             estimator_fit(X_, y[indices])

         estimators.append(estimator)
```

However, I still do not understand why it's needed: there should be no in-place modification of X.
With the check at scikit-learn/sklearn/tree/_classes.py line 170 (in 452ede0), the tree will do an in-place operation that changes the dtype of the data at scikit-learn/sklearn/tree/_tree.pyx lines 102 to 103 (in 452ede0).
The quick fix is:

```diff
diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py
index 4be74d2873..07d28c93e2 100644
--- a/sklearn/ensemble/_iforest.py
+++ b/sklearn/ensemble/_iforest.py
@@ -254,7 +254,7 @@ class IsolationForest(OutlierMixin, BaseBagging):
         self : object
             Fitted estimator.
         """
-        X = self._validate_data(X, accept_sparse=["csc"])
+        X = self._validate_data(X, accept_sparse=["csc"], dtype=np.float32)
         if issparse(X):
             # Pre-sort indices to avoid that each individual tree of the
             # ensemble sorts the indices.
```
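To make the failure mode concrete, here is a minimal sketch of the in-place mutation involved (illustrative code, not from the PR): when a scipy sparse matrix holds float64 data, each tree's validation rebinds its .data attribute to a float32 copy, mutating the container shared by all worker threads.

```python
import numpy as np
import scipy.sparse as sp

# Shared input, as produced by validation with accept_sparse=["csc"]
# but without forcing dtype=np.float32.
X = sp.random(100, 10, density=0.3, format="csc", dtype=np.float64)
buffer_before = X.data

# What each tree's input validation effectively does when the data is
# not float32 yet: swap the .data buffer on the shared object.
if X.data.dtype != np.float32:
    X.data = np.ascontiguousarray(X.data, dtype=np.float32)

# Every other thread holding a reference to X now sees the new buffer,
# possibly mid-swap -- hence the crash under n_jobs > 1.
assert X.data is not buffer_before
assert X.data.dtype == np.float32
```

Validating with dtype=np.float32 up front means this conversion can never trigger inside the parallel section.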
I've updated the code as suggested by @thomasjpfan, and verified that it fixes the segmentation fault issue :)
Could you please add a non-regression test with […]?
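A sketch of such a non-regression test (hypothetical test name; the actual test added by the PR may differ):

```python
import numpy as np
from scipy.sparse import csc_matrix

from sklearn.ensemble import IsolationForest


def test_iforest_sparse_input_with_n_jobs():
    # Non-regression check: fitting on sparse float64 input with
    # multiple jobs used to segfault because the shared input matrix
    # was converted to float32 in place by each tree's validation.
    rng = np.random.RandomState(0)
    X = csc_matrix(rng.randn(100, 10))
    IsolationForest(n_estimators=10, n_jobs=2, random_state=0).fit(X)
```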
Please also add a new changelog entry targeting scikit-learn 1.2 (unless our 1.1 release manager @jeremiedbb thinks it should be part of 1.1 :P)
I confirm this fixes the problem reported in the linked issue. Here is the output of the script on my machine:

- this PR:

```
$ python ~/code/sanbox/iso.py 4
data size: 731.2 MB
Running IsolationForest with n_jobs=4...
duration: 0.9 s
final model size: 38.0 MB
```

- main:

```
$ python ~/code/sanbox/iso.py 4
data size: 731.2 MB
Running IsolationForest with n_jobs=4...
duration: 10.2 s
final model size: 38.0 MB
```
Hi @ogrisel, I've added the test case and changelog as suggested. Meanwhile, I think we still need to make a decision on whether we should deprecate the max_features parameter: currently users may set it freely, which causes a worse runtime when it's not 1.0.
I think we can keep max_features.
Thank you for the update!
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
LGTM
Reference Issues/PRs
Fixes #19275
This is a follow-up PR of #23149.
What does this implement/fix? Explain your changes.
As discussed in the comment, the indexing operation in _bagging.py's _parallel_build_estimators is quite expensive, so when bootstrap_features is set to False and max_features_ is equal to the number of features, we can skip indexing to get better performance.
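For context, integer-array ("fancy") indexing in NumPy always returns a copy, even when it selects every column, so skipping it when it is a no-op saves a full pass over the data. A tiny illustrative sketch (array sizes and names are made up, not the PR's benchmark):

```python
import timeit

import numpy as np

X = np.random.randn(100_000, 50)
features = np.arange(50)  # selects every column, as IsolationForest does

# Integer-array indexing copies the data even though it is a no-op here.
t_copy = timeit.timeit(lambda: X[:, features], number=20)
print(f"20 redundant copies of X took {t_copy:.3f} s")
```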
Benchmark result

Code used for profiling:
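The profiling script itself is collapsed in the original comment (it comes from the discussion in #19275); a rough reconstruction based on the printed output, with made-up data dimensions, might look like:

```python
import pickle
import sys
from time import perf_counter

import numpy as np

from sklearn.ensemble import IsolationForest

n_jobs = int(sys.argv[1]) if len(sys.argv) > 1 else 1

rng = np.random.RandomState(42)
X = rng.randn(200_000, 100)  # dimensions are illustrative only
print(f"data size: {X.nbytes / 1e6:.1f} MB")

print(f"Running IsolationForest with n_jobs={n_jobs}...")
tic = perf_counter()
model = IsolationForest(n_jobs=n_jobs, random_state=0).fit(X)
print(f"duration: {perf_counter() - tic:.1f} s")
print(f"final model size: {len(pickle.dumps(model)) / 1e6:.1f} MB")
```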
Sparse input (profiler screenshots in the original):
- Before: 4.97 s
- After: 0.241 s

Dense input (profiler screenshots in the original):
- Before: 7.99 s
- After: 5.19 s
Any other comments?
bootstrap_features is set to False already for IsolationForest, but users can still set max_features freely, which will cause a worse runtime when it's not 1.0.

Since in IsolationForest we're randomly picking a single feature to split, shall we deprecate the max_features argument and always set it to 1.0 for better runtime?