1
1
"""
2
2
Testing for the tree module (sklearn.tree).
3
3
"""
4
+ import copy
4
5
import pickle
5
- from copy import copy
6
6
from functools import partial
7
7
from itertools import product
8
8
import struct
@@ -1600,6 +1600,7 @@ def test_no_sparse_y_support():
1600
1600
for name in ALL_TREES :
1601
1601
yield (check_no_sparse_y_support , name )
1602
1602
1603
+
1603
1604
def test_mae ():
1604
1605
# check MAE criterion produces correct results
1605
1606
# on small toy dataset
@@ -1613,21 +1614,26 @@ def test_mae():
1613
1614
assert_array_equal (dt_mae .tree_ .impurity , [7.0 / 2.3 , 3.0 / 0.7 , 4.0 / 1.6 ])
1614
1615
assert_array_equal (dt_mae .tree_ .value .flat , [4.0 , 6.0 , 4.0 ])
1615
1616
1617
+
1616
1618
def test_criterion_copy ():
1617
- # Let's check whether copy of our criterion has same type and properties as original
1619
+ # Let's check whether copy of our criterion has the same type
1620
+ # and properties as original
1618
1621
n_outputs = 3
1619
1622
n_classes = np .arange (3 )
1620
- for _ , typename in CRITERIA_CLF .items ():
1621
- criteria = typename (n_outputs , n_classes )
1622
- typename_ , (n_outputs_ , n_classes_ ), _ = copy (criteria ).__reduce__ ()
1623
- assert_equal (typename , typename_ )
1624
- assert_equal (n_outputs , n_outputs_ )
1625
- assert_array_equal (n_classes , n_classes_ )
1626
-
1627
1623
n_samples = 100
1628
- for _ , typename in CRITERIA_REG .items ():
1629
- criteria = typename (n_outputs , n_samples )
1630
- typename_ , (n_outputs_ , n_samples_ ), _ = copy (criteria ).__reduce__ ()
1631
- assert_equal (typename , typename_ )
1632
- assert_equal (n_outputs , n_outputs_ )
1633
- assert_equal (n_samples , n_samples_ )
1624
+ for copy_func in [copy .copy , copy .deepcopy ]:
1625
+ for _ , typename in CRITERIA_CLF .items ():
1626
+ criteria = typename (n_outputs , n_classes )
1627
+ result = copy_func (criteria ).__reduce__ ()
1628
+ typename_ , (n_outputs_ , n_classes_ ), _ = result
1629
+ assert_equal (typename , typename_ )
1630
+ assert_equal (n_outputs , n_outputs_ )
1631
+ assert_array_equal (n_classes , n_classes_ )
1632
+
1633
+ for _ , typename in CRITERIA_REG .items ():
1634
+ criteria = typename (n_outputs , n_samples )
1635
+ result = copy_func (criteria ).__reduce__ ()
1636
+ typename_ , (n_outputs_ , n_samples_ ), _ = result
1637
+ assert_equal (typename , typename_ )
1638
+ assert_equal (n_outputs , n_outputs_ )
1639
+ assert_equal (n_samples , n_samples_ )
0 commit comments