Skip to content

MNT Refactor tree splitter to use memoryviews #23273

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 13, 2022

Conversation

thomasjpfan
Copy link
Member

@thomasjpfan thomasjpfan commented May 4, 2022

This PR refactors the tree splitter to use memoryview and allows Python to manage memory.

Benchmark

Running this benchmark that compares best/random splitter and numpy/sparse input between this PR and main:

Screen Shot 2022-05-04 at 2 22 48 PM

Overall, this PR has the same runtime performance as main for dense input. For sparse input, this PR does a little better.

@thomasjpfan thomasjpfan marked this pull request as draft May 4, 2022 12:48
@thomasjpfan thomasjpfan marked this pull request as ready for review May 4, 2022 12:54
@thomasjpfan thomasjpfan marked this pull request as draft May 4, 2022 13:01
@thomasjpfan thomasjpfan marked this pull request as ready for review May 4, 2022 15:28
@thomasjpfan thomasjpfan marked this pull request as draft May 4, 2022 17:33
@thomasjpfan thomasjpfan marked this pull request as ready for review May 4, 2022 17:57
@thomasjpfan thomasjpfan marked this pull request as draft May 4, 2022 18:08
@thomasjpfan thomasjpfan marked this pull request as ready for review May 4, 2022 18:22
Comment on lines +1055 to +1056
if start_positive < end:
simultaneous_sort(&Xf[start_positive], &samples[start_positive], end - start_positive)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fixes a bug on main where start_positive == end, which can lead to samples[start_positive] being out of bounds.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to write a non-regression test to trigger this case?

Copy link
Member Author

@thomasjpfan thomasjpfan May 9, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few tests that trigger this case causing the CI to fail before. Note these test only fail when compiled with SKLEARN_ENABLE_DEBUG_CYTHON_DIRECTIVES=1, which enables bound checking on memoryviews.

Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the bug fixes would deserve a changelog entry and ideally a non-regression test.

About the use of typed memoryviews, it looks good to me. I tried to see if the new code would not leak any memory using a for loop with psutil prints and all things good (memory usage is stable after a few iterations, as in in main).

I am surprised that we seem to observe a small but significant speed-up with the MSE criterion. I am not sure what is causing this. Maybe the compiler can optimize things further with the C code generated by typed memory views (e.g. contiguity explicitly declared with the [::1] notation?).

@ogrisel
Copy link
Member

ogrisel commented May 9, 2022

I re-ran the benchmark on my own laptop with the latest version of this PR and I the following similar:

Figure_1

I don't think there is any significant difference between this branch and main.

@glemaitre glemaitre self-requested a review May 13, 2022 12:15
cdef SIZE_t n_features # X.shape[1]
cdef DTYPE_t* feature_values # temp. array holding feature values
cdef DTYPE_t[::1] feature_values # temp. array holding feature values

cdef SIZE_t start # Start position for the current node
cdef SIZE_t end # End position for the current node
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you plan to change as sample_weight in a future PR?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you mean?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I was referring to the pointer that is 2 lines below sample_weight.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think @glemaitre means:

cdef DOUBLE_t* sample_weight

Yes I plan to do it in the future. sample_weight touches multiple files, so I wanted to do it in another PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok I see: the sample_weight attribute below is still defined as a pointer (cdef DOUBLE_t* sample_weight) and it could also be changed to a memory view.

+1. I am fine for doing this in a later PR and merge this one.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Works with me.

Copy link
Member

@glemaitre glemaitre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just a small question before merging.

@glemaitre glemaitre merged commit 962c9b4 into scikit-learn:main May 13, 2022
RMeli added a commit to RMeli/scikit-learn that referenced this pull request May 14, 2022
glemaitre added a commit to glemaitre/scikit-learn that referenced this pull request Aug 4, 2022
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
glemaitre added a commit that referenced this pull request Aug 5, 2022
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
sebp added a commit to sebp/scikit-survival that referenced this pull request Oct 17, 2022
Fixes binary incompatability due to
scikit-learn/scikit-learn#23273

==1.1 enforces scikit-learn 1.1.0, but ~=1.1.0
allows for any version up to 1.2
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants