
Conversation

@MaxwellLZH (Contributor) commented Apr 30, 2022

Reference Issues/PRs

Fixes #19275
This is a follow-up PR of #23149.

What does this implement/fix? Explain your changes.

As discussed in the comments on #23149, the indexing operation in _parallel_build_estimators (in _bagging.py) is quite expensive, so when bootstrap_features is set to False and max_features_ is equal to the number of features, we can skip the indexing to get better performance.
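For intuition, here is a small standalone sketch (not part of the PR) showing that column fancy indexing copies the data even when it selects every column, which is exactly the cost this change avoids:

import time
import numpy as np
from scipy.sparse import random as sparse_random

# Same shape as the benchmark below; the density is an arbitrary choice.
X = sparse_random(50000, 1000, density=0.01, format="csc")
features = np.arange(X.shape[1])  # selects all columns, i.e. a no-op

start = time.perf_counter()
_ = X[:, features]  # still materializes a full copy of the data
print(f"no-op column indexing took {time.perf_counter() - start:.3f}s")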

Benchmark result

Code used for profiling:

from sklearn.datasets import make_classification
from scipy.sparse import csc_matrix, csr_matrix
from sklearn.ensemble import IsolationForest

X, y = make_classification(n_samples=50000, n_features=1000)
X = csc_matrix(X)
X.sort_indices()
IsolationForest(n_estimators=10, max_samples=256, n_jobs=1).fit(X)
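The numbers below come from profiling runs; as a rough wall-clock check, one can also wrap the fit call (continuing from the script above, with only time newly imported):

import time

start = time.perf_counter()
IsolationForest(n_estimators=10, max_samples=256, n_jobs=1).fit(X)
print(f"fit took {time.perf_counter() - start:.2f}s")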

Sparse Input:
Before: 4.97 s
After: 0.241 s

Dense Input:
Before: 7.99 s
After: 5.19 s

Any other comments?

bootstrap_features is already set to False for IsolationForest, but users can still set max_features freely, which causes a worse runtime whenever it is not 1.0.

Since IsolationForest randomly picks a single feature to split on, shall we deprecate the max_features argument and always set it to 1.0 for better runtime?

@thomasjpfan (Member) left a comment


Thank you for the follow up PR!

Since IsolationForest randomly picks a single feature to split on, shall we deprecate the max_features argument and always set it to 1.0 for better runtime?

I feel like there is still a use case for having each tree see a different subset of features. What do you think @albertcthomas ?

@@ -135,8 +135,13 @@ def _parallel_build_estimators(
                 not_indices_mask = ~indices_to_mask(indices, n_samples)
                 curr_sample_weight[not_indices_mask] = 0

-            estimator_fit(X[:, features], y, sample_weight=curr_sample_weight)
+            # Indicator of whether indexing is necessary
+            require_indexing = bootstrap_features or max_features != n_features
@thomasjpfan (Member) commented on this diff:

I think we can define this outside the loop. This optimization is also useful for the else: case a few lines below. Something like:

# outside of loop
requires_feature_indexing = bootstrap_features or max_features != n_features

# in loop
    ...
    X_ = X[:, features] if requires_feature_indexing else X
    estimator_fit(X_, y, sample_weight=curr_sample_weight)
else:
    X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
    estimator_fit(X_, y[indices])
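To make the effect of this pattern concrete, here is a standalone sketch (plain NumPy, independent of the bagging internals):

import numpy as np

rng = np.random.default_rng(0)
X = rng.random((5, 3))
features = np.arange(X.shape[1])  # all features selected

# Fancy indexing always copies, even when it selects every column:
assert not np.shares_memory(X, X[:, features])

# The suggested pattern skips the copy when the indexing would be a no-op:
requires_feature_indexing = False  # bootstrap_features=False, max_features == n_features
X_ = X[:, features] if requires_feature_indexing else X
assert X_ is X  # the original array is reused, no copy is made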

@MaxwellLZH (Contributor, Author) replied:

Updated as suggested :)

@thomasjpfan (Member) commented:

Can you update the opening comment to include the benchmark script so it is easier for other reviewers to run the benchmark? (without needing to find it in the other PR)

@MaxwellLZH (Contributor, Author) replied:

Can you update the opening comment to include the benchmark script so it is easier for other reviewers to run the benchmark? (without needing to find it in the other PR)

Sure! I've updated the benchmark result in the opening comment.

@ogrisel (Member) commented May 2, 2022

I get a segfault when I execute the script from #19275 (comment) on this branch with n_jobs=4.

Here is the top of the gdb backtrace:

#0  0x00007fffb4b2b6db in __pyx_f_7sklearn_4tree_9_splitter_18BaseSparseSplitter_extract_nnz () from /home/ogrisel/code/scikit-learn/sklearn/tree/_splitter.cpython-310-x86_64-linux-gnu.so
#1  0x00007fffb4b33722 in __pyx_f_7sklearn_4tree_9_splitter_20RandomSparseSplitter_node_split () from /home/ogrisel/code/scikit-learn/sklearn/tree/_splitter.cpython-310-x86_64-linux-gnu.so
#2  0x00007fffb4b01e4a in __pyx_f_7sklearn_4tree_5_tree_21DepthFirstTreeBuilder_build(__pyx_obj_7sklearn_4tree_5_tree_DepthFirstTreeBuilder*, __pyx_obj_7sklearn_4tree_5_tree_Tree*, _object*, tagPyArrayObject_fields*, int, __pyx_opt_args_7sklearn_4tree_5_tree_21DepthFirstTreeBuilder_build*) () from /home/ogrisel/code/scikit-learn/sklearn/tree/_tree.cpython-310-x86_64-linux-gnu.so
#3  0x00007fffb4ac4970 in __pyx_pw_7sklearn_4tree_5_tree_21DepthFirstTreeBuilder_3build(_object*, _object*, _object*) () from /home/ogrisel/code/scikit-learn/sklearn/tree/_tree.cpython-310-x86_64-linux-gnu.so
#4  0x00005555556997f6 in method_vectorcall_VARARGS_KEYWORDS (func=<optimized out>, args=<optimized out>, nargsf=<optimized out>, kwnames=<optimized out>)
    at /usr/local/src/conda/python-3.10.4/Objects/descrobject.c:344
#5  0x0000555555687f2f in _PyObject_VectorcallTstate (kwnames=0x0, nargsf=<optimized out>, args=0x7fffa4001c40, callable=0x7fffb4c64400, tstate=0x55555673f3d0)
    at /usr/local/src/conda/python-3.10.4/Include/cpython/abstract.h:114
#6  PyObject_Vectorcall (kwnames=0x0, nargsf=<optimized out>, args=0x7fffa4001c40, callable=0x7fffb4c64400) at /usr/local/src/conda/python-3.10.4/Include/cpython/abstract.h:123
#7  call_function (kwnames=0x0, oparg=<optimized out>, pp_stack=<synthetic pointer>, trace_info=0x7fffb2ba1560, tstate=<optimized out>) at /usr/local/src/conda/python-3.10.4/Python/ceval.c:5867
#8  _PyEval_EvalFrameDefault (tstate=<optimized out>, f=<optimized out>, throwflag=<optimized out>) at /usr/local/src/conda/python-3.10.4/Python/ceval.c:4198
#9  0x00005555556a56ac in _PyEval_EvalFrame (throwflag=0, f=0x7fffa4001a10, tstate=0x55555673f3d0) at /usr/local/src/conda/python-3.10.4/Include/internal/pycore_ceval.h:46
#10 _PyEval_Vector (kwnames=<optimized out>, argcount=<optimized out>, args=0x7fffb3afeed0, locals=0x0, con=0x7fffb4c46690, tstate=0x55555673f3d0) at /usr/local/src/conda/python-3.10.4/Python/ceval.c:5065
#11 _PyFunction_Vectorcall (kwnames=<optimized out>, nargsf=<optimized out>, stack=0x7fffb3afeed0, func=0x7fffb4c46680) at /usr/local/src/conda/python-3.10.4/Objects/call.c:342
#12 _PyObject_VectorcallTstate (kwnames=<optimized out>, nargsf=<optimized out>, args=0x7fffb3afeed0, callable=0x7fffb4c46680, tstate=0x55555673f3d0)
    at /usr/local/src/conda/python-3.10.4/Include/cpython/abstract.h:114

@MaxwellLZH (Contributor, Author) commented May 2, 2022

I tried different settings; it seems the segmentation fault happens when using sparse data with n_jobs > 1.

@ogrisel (Member) commented May 2, 2022

There are several places in the code where the data is mutated in place. This is not a problem when working with X[:, features], because that forces a copy each time. But it becomes a problem in your branch, because X is then shared across concurrent threads, and the tree-splitting code is no longer thread-safe.

We need to sort out how to make this thread-safe without introducing costly data-copies.

@ogrisel (Member) commented May 2, 2022

Actually I am not sure. The tree models should not do inplace modifications of sparse input data.

@ogrisel (Member) commented May 2, 2022

I confirm that making an explicit copy of X fixes the problem:

diff --git a/sklearn/ensemble/_bagging.py b/sklearn/ensemble/_bagging.py
index 9c8faa783..7798dc880 100644
--- a/sklearn/ensemble/_bagging.py
+++ b/sklearn/ensemble/_bagging.py
@@ -137,10 +137,10 @@ def _parallel_build_estimators(
                 not_indices_mask = ~indices_to_mask(indices, n_samples)
                 curr_sample_weight[not_indices_mask] = 0
 
-            X_ = X[:, features] if requires_feature_indexing else X
+            X_ = X[:, features] if requires_feature_indexing else X.copy()
             estimator_fit(X_, y, sample_weight=curr_sample_weight)
         else:
-            X_ = X[indices][:, features] if requires_feature_indexing else X[indices]
+            X_ = X[indices][:, features] if requires_feature_indexing else X[indices].copy()
             estimator_fit(X_, y[indices])
 
         estimators.append(estimator)

However, I still do not understand why it's needed: there should be no inplace modification of X.

@thomasjpfan (Member) commented:

With check_input=False, this validation in the tree code does not run:

check_X_params = dict(dtype=DTYPE, accept_sparse="csc")

so it will do an inplace operation here that changes the dtype of the data:

if X.data.dtype != DTYPE:
    X.data = np.ascontiguousarray(X.data, dtype=DTYPE)
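A minimal sketch of why that mutation is a problem (assuming np.float32 as the tree module's internal DTYPE): the dtype conversion rebinds X.data on the one matrix object that all workers share.

import numpy as np
from scipy.sparse import csc_matrix

DTYPE = np.float32  # the tree module's internal dtype

X = csc_matrix(np.ones((3, 3)))  # data is float64 by default
shared_ref = X  # e.g. the same object seen by a concurrent worker

if X.data.dtype != DTYPE:
    X.data = np.ascontiguousarray(X.data, dtype=DTYPE)

# The mutation is visible through every reference to the matrix:
assert shared_ref.data.dtype == np.float32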

The quick fix is:

diff --git a/sklearn/ensemble/_iforest.py b/sklearn/ensemble/_iforest.py
index 4be74d2873..07d28c93e2 100644
--- a/sklearn/ensemble/_iforest.py
+++ b/sklearn/ensemble/_iforest.py
@@ -254,7 +254,7 @@ class IsolationForest(OutlierMixin, BaseBagging):
         self : object
             Fitted estimator.
         """
-        X = self._validate_data(X, accept_sparse=["csc"])
+        X = self._validate_data(X, accept_sparse=["csc"], dtype=np.float32)
         if issparse(X):
             # Pre-sort indices to avoid that each individual tree of the
             # ensemble sorts the indices.

@MaxwellLZH (Contributor, Author) commented:

I've updated the code as suggested by @thomasjpfan and verified that it fixes the segmentation fault issue :)

@ogrisel (Member) commented May 3, 2022

Could you please add a non-regression test with n_jobs=2?
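Such a test might look roughly like this (an illustrative sketch only; the test name and parameters are hypothetical, not the test actually merged):

from scipy.sparse import csc_matrix
from sklearn.datasets import make_classification
from sklearn.ensemble import IsolationForest

def test_isolation_forest_sparse_n_jobs():  # hypothetical name
    # Non-regression test: fitting on sparse input with n_jobs=2 used to segfault.
    X, _ = make_classification(n_samples=200, n_features=10, random_state=0)
    IsolationForest(n_estimators=5, n_jobs=2, random_state=0).fit(csc_matrix(X))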

@ogrisel (Member) commented May 3, 2022

Please also add a new changelog entry targeting scikit-learn 1.2 (unless our 1.1 release manager @jeremiedbb thinks it should be part of 1.1 :P)

@ogrisel (Member) left a comment

I confirm this fixes the problem reported in the linked issue. Here is the output of the script on my machine:

  • this PR
$ python ~/code/sanbox/iso.py 4
data size: 731.2 MB
Running IsolationForest with n_jobs=4...
duration: 0.9 s
final model size: 38.0 MB


  • main
$ python ~/code/sanbox/iso.py 4
data size: 731.2 MB
Running IsolationForest with n_jobs=4...
duration: 10.2 s
final model size: 38.0 MB


@MaxwellLZH (Contributor, Author) commented:

Hi @ogrisel, I've added a test case and a changelog entry as suggested. Meanwhile, I think we still need to decide whether to deprecate the max_features argument and always set it to 1.0 before merging this PR.

Currently, users may set max_features to a smaller value and observe a much slower runtime, which could be quite confusing :(

@ogrisel (Member) commented May 3, 2022

I think we can keep max_features. Maybe we can just add a note to explain that it's much slower to enable feature subsampling on sparse training data for this model.

@thomasjpfan (Member) left a comment

Thank you for the update!

MaxwellLZH and others added 6 commits May 4, 2022 16:06
Co-authored-by: Thomas J. Fan <thomasjpfan@gmail.com>
@thomasjpfan (Member) left a comment

LGTM

@thomasjpfan thomasjpfan merged commit abbeacc into scikit-learn:main May 4, 2022
Successfully merging this pull request may close these issues:

IsolationForest extremely slow with large number of columns having discrete values