-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MAINT] Improve extensibility of the Tree/Splitter code #22756
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thanks for the PR! According to the Oblique PR, this works only if we extend the scikit-learn/sklearn/tree/_split_record.pxd Lines 24 to 25 in 6051b40
This means the extensions only work for trees internal to scikit-learn. If external libraries want to extend trees with a new split record, they would need to adjust From a software design point of view, I think it is strange to have fields in |
Apologies I should've elaborated. This is able to be circumvented by defining a fusedtype. I didn't add this yet because I wasn't sure if you would prefer the soln I have implemented rn. ctypedef fusedtype GeneralSplitRecord:
ObliqueSplitRecord
SplitRecord
# for Cython class functions that now pass around SplitRecord, just do the following
# and it will work because fusedtype is acceptable in function headers
cdef func(GeneralSplitRecord split, ...) The extension to a split record that also stores the vector of indices and weights are sufficient to encompass all possible oblique trees (here we only implement and consider the one proposed by Breiman). |
If you would like, I can modify the ObliquePR to reflect this is true first. I have verified it locally. |
I think the solution is acceptable if fused types work in the context of making ObliqueTree code smaller. With my understanding of fused types, I think it will make the code more complicated for this specific use case. At a high level, its similar to using C++ templates but also with a class hierarchy. I have not explored fused type deeply, so if there is a way to use fused types that is also maintainable, I am open to it. Here is a code snippet with an alternative approach that may work for your use case: cdef struct SplitRecord:
int i
cdef struct ObliqueRecord:
SplitRecord a
int k
# Lets say this is regular tree method
cdef set_node_regular(SplitRecord * a):
print(a.i)
# Lets say this is oblique tree method
cdef set_node_oblique(SplitRecord * a):
cdef ObliqueRecord b = (<ObliqueRecord*>(a))[0]
print(b.a.i)
print(b.k)
def main():
cdef SplitRecord a
a.i = 10
set_node_regular(&a)
cdef ObliqueRecord b
b.a.i = 3
b.k = 4
set_node_oblique(<SplitRecord*>(&b))
|
Hi @thomasjpfan thanks for the input. This is an interesting idea. The issue is that E.g. I think we would need something like this no? # option 1 (idk if this works tho...)
def main():
cdef SplitRecord a
a.i = 10
set_node_regular(&a)
# in your suggestion, we leave build() as is, which then
# requires us to define the obliquesplitrecord within this function
set_node_oblique(&a)
# or option 2
def main():
cdef SplitRecord a
a.i = 10
cdef ObliqueRecord b
b.a = a
b.k = 4
set_node_regular(&a)
set_node_oblique(<SplitRecord>(&b)) If option 2, then we just have to be okay defining an Alternatively, we could use |
For me to properly evaluate, I would need to see an implementation.
I'm trying really hard to avoid Tempita for this use case. I find Tempita very hard to maintain. |
I'm not sure I follow 100% here because it's getting into some deep c++, but this seems like the most desirable path. If what I spec out is correct, then I think this will be very clean. Thanks for answering any bad questions I have :p.
Say it looks something like this. # this would be the struct
cdef struct ObliqueSplitRecord:
# Pointer for a normal split record see _splitter.pxd
shared_ptr[SplitRecord] *split_record
# the following are only used for oblique trees
vector[DTYPE_t]* proj_vec_weights # weights of the vector
vector[SIZE_t]* proj_vec_indices # indices of the features
# then say node_split would need to be redefined as:
cdef int node_split(self,
double impurity, # Impurity of the node
shared_ptr[SplitRecord] split,
SIZE_t* n_constant_features) nogil except -1:
cdef ObliqueSplitRecord oblique_split
oblique_split.split_record = split
... Then once it's passed from the cdef SIZE_t _add_node(self, SIZE_t parent, bint is_left, bint is_leaf,
shared_ptr[SplitRecord] split_node, double impurity,
SIZE_t n_node_samples,
double weighted_n_node_samples) nogil except -1:
# is this then correct and able to get the original oblique_record?
cdef ObliqueRecord oblique_record = (<ObliqueRecord*>(split_node))[0]
# the oblique_record here needs access to the proj_indices and proj_weights
do_something_with(oblique_record.proj_vec_indices) If this is what you had in mind, then I think it'll look very nice! |
Yup, that is the general idea. We can keep things extra simple by having |
Although, it is unfortunate that the pruner needs to define a split record just to add a node. Maybe a new method |
I think Otherwise I'm having difficulty going from SplitRecord to the ObliqueSplitRecord that contains it. The main issue being, |
Can cdef struct ObliqueSplitRecord:
SplitRecord split_record
# the following are only used for oblique trees
vector[DTYPE_t]* proj_vec_weights # weights of the vector
vector[SIZE_t]* proj_vec_indices # indices of the features |
I believe it can be, but when you use the pointer in a downstream function E.g. # this is called within TreeBuilder
cdef int node_split(self, double impurity,
SplitRecord* split,
SIZE_t* n_constant_features) nogil except -1:
# create an ObliqueSplitRecord
cdef ObliqueSplitRecord oblique_split = (<ObliqueSplitRecord*>(split))[0]
# do some stuff to set the projection vector
oblique_split.proj_vect_indices = func(...)
...
# need to set the `split` pointer to what we have here
split[0] = oblique_split.split_record
# Then within ObliqueTree, you need to be able to set the Node Values
cdef int _set_node_values(self, SplitRecord *split_node, Node *node) nogil except -1:
# Re-create the oblique split record that holds the original
cdef ObliqueSplitRecord oblique_split_node = (<ObliqueSplitRecord*>(split_node))[0]
# this now results in a segfault because there is nothing there
with gil:
print(deref(oblique_split_node.proj_vec_weights)) |
Yes. I rather have a few conversions like this than a |
Just added some notes to my original comment. Lmk if you think I'm missing something there. The issue I'm having is getting access to the |
I opened neurodata#16 to showcase how to work around the SplitRecord issue. (Although I may have broke the underlying algorithm.) As for this PR, I do not think it can be merged as is. Making functions cimportable can be done in a case by case basis. Refactoring trees, needs a proof of concept to make sure those are the right abstractions. In other words, oblique trees needs to demonstrate how it will use these new abstractions. After that, we can open a small PR to refactor the trees. |
Reference Issues/PRs
Fixes: #22753
What does this implement/fix? Explain your changes.
_splitter.pxd
definition fileTree
class to enable easier extensionsAny other comments?
N/a