
MNT Refactor tree to share splitters between dense and sparse data #25306


Merged

Conversation

thomasjpfan
Member

This PR refactors the splitters so that the best and random splitters can share the implementation for dense and sparse data. A large portion of the diff comes from de-indenting a class method into a function. Overall, this refactor leads to a reduction of ~200 lines of code.

Implementation Overview

This PR refactors the class method node_split into two functions, node_split_{best|random}, each of which takes a fused type: {Dense|Sparse}Splitter. The fused type avoids the overhead of an inheritance structure, which goes through vtable lookups. I benchmarked an implementation with inheritance and it leads to a ~10% increase in runtime compared to main.
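
To make the dispatch concrete, here is a minimal sketch of the fused-type pattern (the class and method names are illustrative placeholders, not the exact code from this PR):

    from cython cimport final

    @final
    cdef class DenseSplitter:
        cdef double feature_value(self, int i) nogil:
            return 0.0  # placeholder: read straight from a dense array

    @final
    cdef class SparseSplitter:
        cdef double feature_value(self, int i) nogil:
            return 0.0  # placeholder: look the value up in CSR storage

    ctypedef fused DataSplitter:
        DenseSplitter
        SparseSplitter

    cdef double node_split_best(DataSplitter data_splitter, int i) nogil:
        # Cython compiles one specialization of this function per
        # fused-type member, so feature_value resolves to a direct C
        # call with no vtable lookup.
        return data_splitter.feature_value(i)

Each specialization is generated at compile time, so the dense and sparse paths share one source-level implementation while keeping the zero-overhead dispatch of two hand-written functions.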

Benchmarks

Here are the ASV benchmarks with SKLBENCH_PROFILE set to regular and large_scale:

asv continuous  -b RandomForestClassifierBenchmark.time_fit main tree_dense_sparse_refactor_v16

This PR does not introduce any performance changes compared to main.

Regular

       before           after         ratio
     [35b5ee65]       [32682f82]
     <main>           <tree_dense_sparse_refactor_v16>
          4.75±0s       4.75±0.01s     1.00  ensemble.RandomForestClassifierBenchmark.time_fit('dense', 1)
       5.66±0.02s       5.70±0.01s     1.01  ensemble.RandomForestClassifierBenchmark.time_fit('sparse', 1)

Large Scale

       before           after         ratio
     [35b5ee65]       [32682f82]
     <main>           <tree_dense_sparse_refactor_v16>
       23.4±0.01s       23.3±0.03s     1.00  ensemble.RandomForestClassifierBenchmark.time_fit('dense', 1)
       28.2±0.02s       27.3±0.03s     0.97  ensemble.RandomForestClassifierBenchmark.time_fit('sparse', 1)

ctypedef fused DataSplitterFused:
Member


Would a 3rd party package that inherits the splitting functions theoretically be able to use the node_split_best and node_split_random functions?

E.g. if 1) splitter was an ObliqueSplitter subclass that inherits Splitter and 2) data_splitter was an ObliqueDenseSplitter? My understanding is that data_splitter would make it unusable. Is that right?

Member Author

@thomasjpfan Jan 5, 2023


The data splitters themselves and node_split_* are not cimportable. A third-party ObliqueDenseSplitter would not be able to use the data_splitter or the node_split_{best|random} functions.

This is the current status on main: only the Splitter is cimportable. The extension point provided by node_split_{best|random} for data_splitter is focused on the current codebase. For the most flexibility, I think a third party will need to define its own Splitter.node_split method, as is done in the Oblique Tree PR #22754.
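
As a rough sketch of that route (the cimport paths and the node_split signature follow the Splitter API at the time of this PR as I understand it; the ObliqueSplitter body is hypothetical):

    from sklearn.tree._splitter cimport Splitter, SplitRecord
    from sklearn.tree._tree cimport SIZE_t

    cdef class ObliqueSplitter(Splitter):
        cdef int node_split(self, double impurity, SplitRecord* split,
                            SIZE_t* n_constant_features) nogil except -1:
            # The data splitters and node_split_{best|random} are not
            # cimportable, so the full splitting loop lives here.
            return 0  # placeholder for the oblique splitting logic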

@final
Member


Related to https://github.com/scikit-learn/scikit-learn/pull/25306/files#r1062801209. Would it make sense to have a base class for "Dense" and "Sparse" splitting that is not final?

And the fused type DataSplitter can come from the base classes rather than the final implementations?

Member Author

@thomasjpfan Jan 5, 2023


I tried using a base class and it runs ~10% slower compared to main. I think the overhead comes from the many virtual table lookups: node_split_* calls into data_splitter's methods a lot, and node_split_* itself is called a lot when building the trees.

Member Author

@thomasjpfan Jan 5, 2023


BTW, there can be a parent class DataSplitter as part of the fused type to allow subclasses, but DenseSplitter and SparseSplitter still need to be final. The final is required to avoid the vtable lookups that would lead to a ~10% runtime regression.

I prefer this PR to focus on the refactor and not add new third-party extension points.
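
Extending the earlier sketch, the layout described here might look like the following (again illustrative, not code from this PR):

    cdef class BaseDataSplitter:
        # Not final: third parties could subclass this, accepting
        # virtual dispatch on the generic path.
        cdef double feature_value(self, int i) nogil:
            return 0.0

    @final
    cdef class DenseSplitter(BaseDataSplitter):
        cdef double feature_value(self, int i) nogil:
            return 0.0  # dense override; devirtualized because final

    @final
    cdef class SparseSplitter(BaseDataSplitter):
        cdef double feature_value(self, int i) nogil:
            return 0.0  # sparse override; devirtualized because final

    ctypedef fused DataSplitter:
        BaseDataSplitter   # generic path for subclasses (vtable lookups)
        DenseSplitter      # final: direct C calls
        SparseSplitter     # final: direct C calls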


cdef inline SIZE_t parition_samples(self, double current_threshold) nogil:
Member


Suggested change:
- cdef inline SIZE_t parition_samples(self, double current_threshold) nogil:
+ cdef inline SIZE_t partition_samples(self, double current_threshold) nogil:

Member


Would it make sense to actually name it partition_samples_random since there is also a partition_samples_best?

Member Author


I renamed this to partition_samples_final. It is called by both the best and random splitter when the split point is finalized.

Member

@adam2392 left a comment


Cool PR @thomasjpfan!

@glemaitre self-requested a review January 6, 2023 15:46
Member

@glemaitre left a comment


A first round of nitpicks. I will go more into the implementation now. But given that the tests pass and that there is no regression, I think this is good as a minimalist change for refactoring.

thomasjpfan and others added 3 commits January 8, 2023 19:47
Member

@glemaitre left a comment


LGTM

@adam2392 (Member) commented Jan 9, 2023

@ogrisel FYI, as discussed in the monthly dev meeting. I verified that this PR does not affect the pytest performance of the oblique PR, so it would not affect 3rd party use cases of node splitting. I don't foresee this affecting the separation of "leaf/split" nodes either, so this refactoring primarily improves maintenance internally and has negligible external effect, which is a great pro.

https://github.com/neurodata/scikit-learn/tree/oliquecheck

The main difference is that the implementation of the underlying node_split does not follow the internal Cython API of sklearn anymore. However, I think that's fine because it still follows the necessary API such that the Tree* code still works.

Member

@ogrisel left a comment


Thanks for the PR @thomasjpfan. I am fine with the core idea of the refactoring, but I think we should clarify the code a bit by:

  • finding a better name for the data splitters (see inline comment below);
  • rewriting their docstrings to make their purpose more explicit and avoid confusing them with subclasses of the Splitter base class;
  • finding a way to summarize the information in the PR discussion about what is intentionally not cimportable by third-party code that builds upon the scikit-learn Cython code base.

I don't think it would be easy to make the node_split_* code cimportable by third-party libraries, so I think we can keep the status quo in that respect.

Member

@ogrisel left a comment


LGTM. Thanks @thomasjpfan for the PR and thanks @adam2392 for the feedback.

@ogrisel ogrisel merged commit 93533dd into scikit-learn:main Jan 11, 2023
jjerphan pushed a commit to jjerphan/scikit-learn that referenced this pull request Jan 20, 2023
MNT Refactor tree to share splitters between dense and sparse data (scikit-learn#25306)



Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
Co-authored-by: Adam Li <adam2392@gmail.com>
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>