[MRG+1] _criterion.pyx cleanup #5278

Merged
glouppe merged 1 commit into scikit-learn:master on Sep 22, 2015

Conversation

jmschrei
Member

This pull request focuses on cleaning up the criterion code without adding any new functionality. It is the first step towards adding caching between splits; that branch had too many changes, so I broke it up to make it easier to review. This PR addresses the following:

(1) Eliminates nearly 200 lines of unneeded code (mostly by using class attributes such as self.n_outputs directly instead of unpacking them)
(2) Standardizes label_count for classification criteria and sum_total for regression criteria into node_sum (and likewise for the left and right versions)
(3) Adds yw_sq_sum as the last sufficient statistic needed for regression criteria between splits
(4) Previously, regression child_impurity recalculated sufficient statistics by rescanning the samples, even though they were already available. I fixed this (see the sketch below), although the speed increase is minimal given how rarely this function is called.
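
As a quick illustration of (3) and (4), here is a minimal pure-Python sketch of how the MSE impurity of a node can be read directly off cached sufficient statistics instead of rescanning the samples. This is not the Cython code in this PR; the function and argument names are illustrative only.

def mse_impurity(node_sum, sq_sum, weighted_n):
    # node_sum   = sum_i w_i * y_i        (cached per node)
    # sq_sum     = sum_i w_i * y_i ** 2   (cached per node)
    # weighted_n = sum_i w_i              (cached per node)
    mean = node_sum / weighted_n
    return sq_sum / weighted_n - mean ** 2

# Three samples with unit weights and targets [1.0, 2.0, 3.0]:
# node_sum = 6.0, sq_sum = 14.0, weighted_n = 3.0 -> variance = 2/3
print(mse_impurity(6.0, 14.0, 3.0))  # 0.666...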

All unit tests pass, and speed is basically the same as before.

ping @arjoly @glouppe @ogrisel

cdef SIZE_t p = 0
cdef SIZE_t k = 0
cdef SIZE_t c = 0
cdef SIZE_t i, p, k, c
Member Author

I understand that we want to put each variable on a separate line, but in the case of single-character iterators it seems like a waste of space. I also disagree with initializing them to 0, because that implies the 0 means something (as with offset below), when the value just gets reset anyway.

Member

Some compilers might issue a warning because you have uninitialized variables.

Member Author

Ah, I didn't know that. Do you think this will be an issue? I haven't gotten any warnings on my Windows or Ubuntu machines.

Member

Maybe those are not activated during compilation. No big deal.

Contributor

-1 for putting them all on one line. If we go in that direction, then we will start wondering when it is or isn't acceptable to group them. I don't mind losing a few lines.

Member Author

For me the rule is that if they are used as iterators or indices then they can go on the same line, since they are just helper variables. Putting the variables on separate lines makes the functions look, to me, more complicated than they actually are. However, I don't feel too strongly about it, so I can change it back if you disagree.

@jmschrei mentioned this pull request on Sep 16, 2015
@@ -226,9 +236,9 @@ cdef class ClassificationCriterion(Criterion):
self.weighted_n_right = 0.0
Member

Those lines could be put in the parent class.

Member Author

Unfortunately, classification and regression criteria have a different number of parameters in their __cinit__ methods, so they can't share one. A possibility is to add a dummy parameter to the regression one, with an unused default value. What do you think?

@arjoly
Member

arjoly commented Sep 16, 2015

(1) Eliminates nearly 200 lines of unneeded code (mostly by using class attributes such as self.n_outputs directly instead of unpacking them)

You remove micro-optimizations. Maybe @glouppe will have some thoughts about that.

(2) Standardizes label_count for classification criteria and sum_total for regression criteria into node_sum (and likewise for the left and right versions)

+1

(3) Adds yw_sq_sum as the last sufficient statistic needed for regression criteria between splits
(4) Previously, regression child_impurity recalculated sufficient statistics by rescanning the samples, even though they were already available. I fixed this. However, the speed increase is minimal given how rarely this function is called.

I would do this in another PR. You might obtain more speed-up if this were cached for all possible splits (and for all weak estimators in GBRT; edit: through criterion re-use).

@arjoly
Member

arjoly commented Sep 16, 2015

I would have also expected that you would cache w_y once and for all in the MSE criterion.

@arjoly
Member

arjoly commented Sep 16, 2015

(3) Adds yw_sq_sum as the last sufficient statistic needed for regression criteria between splits
(4) Previously, regression child_impurity recalculated sufficient statistics by rescanning the samples, even though they were already available. I fixed this. However, the speed increase is minimal given how rarely this function is called.

I would do this in another PR. You might obtain more speed-up if this were cached for all possible splits (and for all weak estimators in GBRT).

This could be attained if we split node_reset and the criterion initialization out of the current init method.

@jmschrei
Member Author

Thanks for the review @arjoly!

Regarding (1), I don't know if the micro-optimizations add anything. @ogrisel mentioned that they might if it were a Python object, but as a Cython object they don't seem to change anything. My speed tests don't seem to indicate a huge difference.

I don't understand your point about (4). I modified the regression criterion to work in the same way as the classification criterion, in that it uses already-computed statistics to calculate the impurity instead of recalculating them from scratch.

I would have also expected that you would cache w_y once and for all in the MSE criterion.

I'm sorry, I don't understand?

@arjoly
Member

arjoly commented Sep 16, 2015

Hm, I thought that you were making another optimization. The w_y and w_sq_y are computed many times, while they could be cached once and for all during tree construction. Hadn't you thought about that in a previous PR? If not, I can make a quick pull request later.

Regarding (1), I don't know if the micro-optimizations add anything. @ogrisel mentioned that they might if it were a Python object, but as a Cython object they don't seem to change anything. My speed tests don't seem to indicate a huge difference.

You need to have a look at the Cython-generated code. There is the cython -a option, which makes it easy to view the generated code (it produces an annotated HTML file).

@jmschrei
Member Author

I am working towards caching w_y and w_sq_y between splits. However, it was too much for a single pull request, so I broke it into two for easier review. This is the first PR. I will be submitting the second PR soon, after this one is merged.
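
To make the caching idea concrete, here is a hypothetical sketch in plain Python with NumPy (not the scikit-learn Cython code; every name here is illustrative): compute w*y and w*y**2 once per sample when tree construction starts, so that every later split only needs sums over index ranges.

import numpy as np

def precompute_weighted_targets(y, sample_weight):
    # Cached once for the whole tree construction
    w_y = sample_weight * y
    w_y_sq = sample_weight * y ** 2
    return w_y, w_y_sq

def node_statistics(w_y, w_y_sq, sample_weight, samples, start, end):
    # samples[start:end] holds the indices of the samples falling into this node
    idx = samples[start:end]
    return w_y[idx].sum(), w_y_sq[idx].sum(), sample_weight[idx].sum()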


cdef double yw_sq_sum # Cumulative square in the current node
cdef double yw_sq_sum_left # Cumulative square in the left node
cdef double yw_sq_sum_right # Cumulative square in the right node
Member

I would rename this sq_sum_XXX. I originally thought that this was the y vector squared.

Member

Could this also be re-used for the Gini criterion?

Member Author

With Gini you need to square the counts of your labels, so it's not as simple to keep a running sum. If you were to update this value after moving one label over, you'd need to subtract 2n-1 or add 2n+1, where n is your previously observed label count. I attempted this optimization when I did the Gini proxy impurity improvement, and it didn't seem to have any speed gain. It was a disappointment.
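
A small plain-Python illustration of that incremental update (the names here are hypothetical, not the Cython criterion): keep a running sum of squared class counts, so the Gini impurity can be refreshed when a sample of class c moves into the node.

def gini(sq_count_sum, n):
    # Gini impurity from the cached sum of squared class counts
    return 1.0 - sq_count_sum / (n * n)

counts = {0: 3, 1: 2}                               # per-class counts in the node
n = 5
sq_count_sum = sum(v * v for v in counts.values())  # 9 + 4 = 13

c = 1                                               # one more class-1 sample arrives
sq_count_sum += 2 * counts[c] + 1                   # n_c**2 grows by 2*n_c + 1
counts[c] += 1
n += 1

assert sq_count_sum == sum(v * v for v in counts.values())
print(gini(sq_count_sum, n))                        # 1 - 18/36 = 0.5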

@jmschrei
Member Author

The Cython-generated code looks fairly clean.

  /* "sklearn/tree/_criterion.pyx":512
 *         cdef SIZE_t k, c
 * 
 *         for k in range(self.n_outputs):             # <<<<<<<<<<<<<<
 *             for c in range(self.n_classes[k]):
 *                 count_k = node_sum[c]
 */
  __pyx_t_2 = __pyx_v_self->__pyx_base.__pyx_base.n_outputs;
  for (__pyx_t_3 = 0; __pyx_t_3 < __pyx_t_2; __pyx_t_3+=1) {
    __pyx_v_k = __pyx_t_3;

@arjoly
Member

arjoly commented Sep 16, 2015

Good. Is it the same for arrays?

edit: great! Cython does the pointer dereferencing for us.

@arjoly
Member

arjoly commented Sep 16, 2015

This is not the case for arrays.

 /* "sklearn/tree/_criterion.pyx":420
 *         if (new_pos - pos) <= (end - new_pos):
 *             for p in range(pos, new_pos):
 *                 i = self.samples[p]             # <<<<<<<<<<<<<<
 * 
 *                 if self.sample_weight != NULL:
 */
      __pyx_v_i = (__pyx_v_self->__pyx_base.samples[__pyx_v_p]);

So I would continue to do it for array pointers.

@jmschrei
Member Author

Comments have been addressed. I've reverted to unpacking array pointers, and it looks like it's been fixed in the Cython-generated code. Thanks again @arjoly!

@glouppe
Contributor

glouppe commented Sep 16, 2015

(I am at a conference this week, but this PR and the one for presorting are on my to-do list. I'll try to review as soon as I find time.)

@arjoly
Member

arjoly commented Sep 16, 2015

Checking more, I found that Cython doesn't do the pointer dereferencing everywhere.

      /* "sklearn/tree/_criterion.pyx":451
 *                     node_sum_left[label_index] -= w
 * 
 *                 self.weighted_n_left -= w             # <<<<<<<<<<<<<<
 *                 self.weighted_n_right += w
 * 
 */
      __pyx_v_self->__pyx_base.weighted_n_left = (__pyx_v_self->__pyx_base.weighted_n_left - __pyx_v_w);

Hm :-/ I don't have a strong opinion on this, except that we should not decrease performance at all through refactoring.

@@ -40,6 +40,14 @@ cdef class Criterion:
cdef double weighted_n_left # Weighted number of samples in the left node
cdef double weighted_n_right # Weighted number of samples in the right node

cdef double* node_sum_total # An array of the current node value
cdef double* node_sum_left # An array of the left node value
cdef double* node_sum_right # An array of the right node value
Contributor

Is the node_ prefix really helpful? I would remove it. I would also change the comment: from their names, these fields rather correspond to sums of values.

Contributor

Also, describe what node_sum_total[i] is.

Contributor

Just one thing though: how are these fields different from the previous label_count_*? Did you switch to these names because they are shared between classification and regression?

Member Author

They are not different, it is just a renaming. The goal is to share them between classification and regression, so that when we do caching between splits we have a unified name to pass through the splitter regardless of the criterion.

@glouppe
Contributor

glouppe commented Sep 17, 2015

Regarding the removal of the unpackings, can you make sure there is no regression in terms of speed? (The intent was to skip the repeated indirect accesses.)


cdef double sq_sum_total # Cumulative square in the current node
cdef double sq_sum_left # Cumulative square in the left node
cdef double sq_sum_right # Cumulative square in the right node
Member

If these are used only in regression, they should be declared in RegressionCriterion.

@jmschrei
Member Author

@glouppe thanks for the review!

I've incorporated the comments, including moving the initialization of iterators back onto separate lines.

BRANCH
DecisionTreeRegressor
spambase   0.902 0.077
Gaussian   0.731 8.911
mnist      0.846 22.196
covtypes   0.938 9.837

RandomForestClassifier
spambase   0.944 0.047
Gaussian   0.912 4.038
mnist      0.948 4.937
covtypes   0.944 13.735

ExtraTreesClassifier
spambase   0.944 0.041
Gaussian   0.917 0.696
mnist      0.953 4.73
covtypes   0.939 11.414

GradientBoostingClassifier
spambase   0.595 0.016
Gaussian   0.648 1.275
mnist      0.698 18.907
covtypes   0.639 11.655

DecisionTreeClassifier
spambase   0.906 0.07
Gaussian   0.73 9.718
mnist      0.869 22.907
covtypes   0.942 9.045

GradientBoostingRegressor
boston       2.876 0.007
regression   3.354 1.572
diabetes    45.739 0.005

DecisionTreeRegressor
boston       2.121 0.004
regression   4.363 1.695
diabetes    53.329 0.003

MASTER
DecisionTreeRegressor
spambase   0.902 0.078
Gaussian   0.731 8.813
mnist      0.846 21.587
covtypes   0.938 9.676

RandomForestClassifier
spambase   0.944 0.047
Gaussian   0.912 4.047
mnist      0.948 4.938
covtypes   0.944 14.062

ExtraTreesClassifier
spambase   0.944 0.041
Gaussian   0.917 0.671
mnist      0.953 4.734
covtypes   0.939 11.277

GradientBoostingClassifier
spambase   0.595 0.015
Gaussian   0.648 1.24
mnist      0.698 18.57
covtypes   0.639 11.631

DecisionTreeClassifier
spambase   0.906 0.07
Gaussian   0.73 9.692
mnist      0.869 22.995
covtypes   0.942 9.184

GradientBoostingRegressor
boston       2.876 0.007
regression   3.354 1.584
diabetes    45.739 0.005

DecisionTreeRegressor
boston       2.121 0.004
regression   4.363 1.798
diabetes    53.329 0.003

On the MNIST benchmark:

BRANCH
Classification performance:
===========================
Classifier               train-time   test-time   error-rate
------------------------------------------------------------
ExtraTrees                   53.47s       0.55s       0.0294
RandomForest                 53.37s       0.50s       0.0318
CART                         25.50s       0.01s       0.1219


MASTER
Classification performance:
===========================
Classifier               train-time   test-time   error-rate
------------------------------------------------------------
ExtraTrees                   53.59s       0.56s       0.0294
RandomForest                 53.75s       0.51s       0.0318
CART                         28.14s       0.35s       0.1219

I don't see any regression in performance.

@glouppe
Contributor

glouppe commented Sep 17, 2015

Thanks for checking! Then removing the unpackings is indeed a good thing to simplify the code.

impurity_left[0] -= ((sum_left[k] /
                      self.weighted_n_left) ** 2.0)
impurity_right[0] -= ((sum_right[k] /
                       self.weighted_n_right) ** 2.0)
Member

Some parentheses are unnecessary.

@arjoly
Member

arjoly commented Sep 18, 2015

(4) Previously, regression child_impurity recalculated sufficient statistics by rescanning the samples, even though they were already available. I fixed this. However, the speed increase is minimal given how rarely this function is called.

Concerning this, are you sure that you get any speed-up? As far as I see, you add new computation to the update step, which is only re-used whenever the children impurity is called.

I am not sure there is a gain in computing sq_sum_total the way it's done here. The node_impurity method is only called once, but here you compute this statistic at each init call.

@arjoly
Member

arjoly commented Sep 18, 2015

(3) Adds yw_sq_sum as the last sufficient statistic needed for regression criteria between splits

What is the current name for this?

@jmschrei
Member Author

With the way you coded it, @arjoly, there is no need to cache sq_sum, which is good. I will remove it in the next revision. This was a misunderstanding on my part.

I am moving back to the US from France today. There may be delays responding to this until I return.

@@ -356,23 +355,20 @@ cdef class ClassificationCriterion(Criterion):
self.weighted_n_left = 0.0
self.weighted_n_right = self.weighted_n_node_samples

cdef SIZE_t n_outputs = self.n_outputs
cdef SIZE_t* n_classes = self.n_classes
Member

Should you unpack this one?

@arjoly
Member

arjoly commented Sep 21, 2015

Up to some tiny variations, we have the same performance on my machine.

# master mnist
ExtraTrees                   43.54s       0.75s       0.0294
RandomForest                 42.74s       0.48s       0.0318

# thispr mnist
ExtraTrees                   43.20s       0.57s       0.0294
RandomForest                 42.57s       0.48s       0.0318

+1 for merging whenever the last nitpicks have been addressed.

@jmschrei
Member Author

I made a comment which I think got masked by the new commit:

Currently, we have self.weighted_n_left, self.weighted_n_right, and self.weighted_n_node_samples as the weight names. It may be easier to understand (and consistent with our other names) if we make them self.weight_left, self.weight_right, and self.weight_total. What do you think?

@glouppe
Contributor

glouppe commented Sep 21, 2015

Currently, we have self.weighted_n_left, self.weighted_n_right, and self.weighted_n_node_samples as the weight names. It may be easier to understand (and consistent with our other names) if we make them self.weight_left, self.weight_right, and self.weight_total. What do you think?

I have to admit that I find weighted_n_node_samples easier to understand than what you propose: without even looking at the code, I can deduce it is the sum of sample weights for the samples in the node. With total, left and right we don't know what these are referring to; it could be the sum of the sample weights or the sum of the y values.

@arjoly
Member

arjoly commented Sep 21, 2015

I agree with Gilles.

Should we rename the suffix total to node?

@jmschrei
Member Author

I disagree on this issue; weight_left seems fairly intuitive to me. However, I may be too close to the code, so I'll defer to you two here.

I've incorporated all the suggested comments. Thanks for the reviews!

@glouppe
Contributor

glouppe commented Sep 22, 2015

@jmschrei Can you fix this last nitpick regarding the comment in _criterion.pxd? Given @arjoly's +1, I will merge after that.

@glouppe changed the title from [MRG] _criterion.pyx cleanup to [MRG+1] _criterion.pyx cleanup on Sep 22, 2015
@jmschrei
Member Author

Thanks for the review, updated! Next step, caching across splits.

@glouppe
Contributor

glouppe commented Sep 22, 2015

Thanks for the review, updated! Next step, caching across splits.

Great! Waiting for the green light from CI and I'll hit the merge button.

@jmschrei
Member Author

Wait, I missed something in _criterion.pxd.

@jmschrei
Member Author

Okay, sorry =)

glouppe added a commit that referenced this pull request Sep 22, 2015
@glouppe merged commit 95fe122 into scikit-learn:master on Sep 22, 2015
@arjoly
Member

arjoly commented Sep 22, 2015

thanks @jmschrei !!!

@GaelVaroquaux
Member

GaelVaroquaux commented Sep 22, 2015 via email

@jmschrei
Member Author

🎺
