-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[MRG+2] Fix #6420 Cloning decision tree estimators breaks criterion objects #7680
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
""" | ||
Testing for the tree module (sklearn.tree). | ||
""" | ||
import copy | ||
import pickle | ||
from functools import partial | ||
from itertools import product | ||
|
@@ -42,12 +43,14 @@ | |
|
||
from sklearn import tree | ||
from sklearn.tree._tree import TREE_LEAF | ||
from sklearn.tree.tree import CRITERIA_CLF | ||
from sklearn.tree.tree import CRITERIA_REG | ||
from sklearn import datasets | ||
|
||
from sklearn.utils import compute_sample_weight | ||
|
||
CLF_CRITERIONS = ("gini", "entropy") | ||
REG_CRITERIONS = ("mse", "mae") | ||
REG_CRITERIONS = ("mse", "mae", "friedman_mse") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh boy wasn't this tested at all ;( |
||
|
||
CLF_TREES = { | ||
"DecisionTreeClassifier": DecisionTreeClassifier, | ||
|
@@ -1597,6 +1600,7 @@ def test_no_sparse_y_support(): | |
for name in ALL_TREES: | ||
yield (check_no_sparse_y_support, name) | ||
|
||
|
||
def test_mae(): | ||
# check MAE criterion produces correct results | ||
# on small toy dataset | ||
|
@@ -1609,3 +1613,30 @@ def test_mae(): | |
dt_mae.fit([[3],[5],[3],[8],[5]],[6,7,3,4,3], [0.6,0.3,0.1,1.0,0.3]) | ||
assert_array_equal(dt_mae.tree_.impurity, [7.0/2.3, 3.0/0.7, 4.0/1.6]) | ||
assert_array_equal(dt_mae.tree_.value.flat, [4.0, 6.0, 4.0]) | ||
|
||
|
||
def test_criterion_copy(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PEP8: extra blank line, please |
||
# Let's check whether copy of our criterion has the same type | ||
# and properties as original | ||
n_outputs = 3 | ||
n_classes = np.arange(3, dtype=np.intp) | ||
n_samples = 100 | ||
|
||
def _pickle_copy(obj): | ||
return pickle.loads(pickle.dumps(obj)) | ||
for copy_func in [copy.copy, copy.deepcopy, _pickle_copy]: | ||
for _, typename in CRITERIA_CLF.items(): | ||
criteria = typename(n_outputs, n_classes) | ||
result = copy_func(criteria).__reduce__() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rather could you pickle, load all the Criteria? Currently this fails in master... import pickle
from sklearn.tree._criterion import FriedmanMSE
fmse = FriedmanMSE(1, 10)
pickle.loads(pickle.dumps(fmse)) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added pickling to this list. |
||
typename_, (n_outputs_, n_classes_), _ = result | ||
assert_equal(typename, typename_) | ||
assert_equal(n_outputs, n_outputs_) | ||
assert_array_equal(n_classes, n_classes_) | ||
|
||
for _, typename in CRITERIA_REG.items(): | ||
criteria = typename(n_outputs, n_samples) | ||
result = copy_func(criteria).__reduce__() | ||
typename_, (n_outputs_, n_samples_), _ = result | ||
assert_equal(typename, typename_) | ||
assert_equal(n_outputs, n_outputs_) | ||
assert_equal(n_samples, n_samples_) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jnothman Thanks for suggesting this change. Indeed this is better.
@olologin Could you also remove the copy from this line and update the comment there to note that the function returns a copy and does not refer to the passed pointers...
Additionally it would be nice to change the docstring of this helper to note that we return a copy of numpy array and not not simply 'Encapsulate data' as it is mentioned currently...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.