Skip to content

FIX Uses log2 in tree building #30557

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats_new/upcoming_changes/sklearn.tree/30557.fix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
- Use `log2` instead of `ln` for building trees to maintain behavior of previous
versions. By `Thomas Fan`_
13 changes: 9 additions & 4 deletions sklearn/tree/_partitioner.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ and sparse data stored in a Compressed Sparse Column (CSC) format.
# SPDX-License-Identifier: BSD-3-Clause

from cython cimport final
from libc.math cimport isnan, log
from libc.math cimport isnan, log2
from libc.stdlib cimport qsort
from libc.string cimport memcpy

Expand Down Expand Up @@ -503,8 +503,8 @@ cdef class SparsePartitioner:
# O(n_samples * log(n_indices)) is the running time of binary
# search and O(n_indices) is the running time of index_to_samples
# approach.
if ((1 - self.is_samples_sorted) * n_samples * log(n_samples) +
n_samples * log(n_indices) < EXTRACT_NNZ_SWITCH * n_indices):
if ((1 - self.is_samples_sorted) * n_samples * log2(n_samples) +
n_samples * log2(n_indices) < EXTRACT_NNZ_SWITCH * n_indices):
extract_nnz_binary_search(X_indices, X_data,
indptr_start, indptr_end,
samples, self.start, self.end,
Expand Down Expand Up @@ -702,12 +702,17 @@ cdef inline void shift_missing_values_to_left_if_required(
best.pos += best.n_missing


def _py_sort(float32_t[::1] feature_values, intp_t[::1] samples, intp_t n):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this function to test sort directly.

"""Used for testing sort."""
sort(&feature_values[0], &samples[0], n)


# Sort n-element arrays pointed to by feature_values and samples, simultaneously,
# by the values in feature_values. Algorithm: Introsort (Musser, SP&E, 1997).
cdef inline void sort(float32_t* feature_values, intp_t* samples, intp_t n) noexcept nogil:
if n == 0:
return
cdef intp_t maxd = 2 * <intp_t>log(n)
cdef intp_t maxd = 2 * <intp_t>log2(n)
introsort(feature_values, samples, n, maxd)


Expand Down
23 changes: 23 additions & 0 deletions sklearn/tree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
DENSE_SPLITTERS,
SPARSE_SPLITTERS,
)
from sklearn.tree._partitioner import _py_sort
from sklearn.tree._tree import (
NODE_DTYPE,
TREE_LEAF,
Expand Down Expand Up @@ -2814,3 +2815,25 @@ def test_build_pruned_tree_infinite_loop():
ValueError, match="Node has reached a leaf in the original tree"
):
_build_pruned_tree_py(pruned_tree, tree.tree_, leave_in_subtree)


def test_sort_log2_build():
"""Non-regression test for gh-30554.

Using log2 and log in sort correctly sorts feature_values, but the tie breaking is
different which can results in placing samples in a different order.
"""
rng = np.random.default_rng(75)
some = rng.normal(loc=0.0, scale=10.0, size=10).astype(np.float32)
feature_values = np.concatenate([some] * 5)
samples = np.arange(50)
_py_sort(feature_values, samples, 50)
# fmt: off
# no black reformatting for this specific array
expected_samples = [
0, 40, 30, 20, 10, 29, 39, 19, 49, 9, 45, 15, 35, 5, 25, 11, 31,
41, 1, 21, 22, 12, 2, 42, 32, 23, 13, 43, 3, 33, 6, 36, 46, 16,
26, 4, 14, 24, 34, 44, 27, 47, 7, 37, 17, 8, 38, 48, 28, 18
]
# fmt: on
assert_array_equal(samples, expected_samples)
Loading