Skip to content

TST Add TODO for global_dtype in sklearn/tree/tests/test_tree.py #22926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
30 changes: 22 additions & 8 deletions sklearn/tree/tests/test_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import pytest
import numpy as np
from numpy.testing import assert_allclose
from scipy.sparse import csc_matrix
from scipy.sparse import csr_matrix
from scipy.sparse import coo_matrix
Expand All @@ -34,6 +33,7 @@
from sklearn.utils._testing import create_memmap_backed_data
from sklearn.utils._testing import ignore_warnings
from sklearn.utils._testing import skip_if_32bit
from sklearn.utils._testing import assert_allclose

from sklearn.utils.estimator_checks import check_sample_weights_invariance
from sklearn.utils.validation import check_random_state
Expand Down Expand Up @@ -260,9 +260,11 @@ def test_weighted_classification_toy():
assert_array_equal(clf.predict(T), true_result, "Failed with {0}".format(name))


# TODO: Tree-based models do not preserve dtype on their fitted attributes.
# Add a test using `global_dtype` when this is the case.
@pytest.mark.parametrize("Tree", REG_TREES.values())
@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_regression_toy(Tree, criterion):
def test_regression_toy(Tree, criterion, global_dtype):
# Check regression on a toy dataset.
if criterion == "poisson":
# make target positive while not touching the original y and
Expand All @@ -274,13 +276,18 @@ def test_regression_toy(Tree, criterion):
y_train = y
y_test = true_result

y_train = np.array(y_train).astype(global_dtype, copy=False)
y_test = np.array(y_test).astype(global_dtype, copy=False)

reg = Tree(criterion=criterion, random_state=1)
reg.fit(X, y_train)
assert_allclose(reg.predict(T), y_test)
reg.fit(np.array(X).astype(global_dtype), y_train)
y_pred = reg.predict(np.array(T).astype(global_dtype))
assert_allclose(y_pred, y_test)

clf = Tree(criterion=criterion, max_features=1, random_state=1)
clf.fit(X, y_train)
assert_allclose(reg.predict(T), y_test)
clf.fit(np.array(X).astype(global_dtype), y_train)
y_pred = clf.predict(np.array(T).astype(global_dtype))
assert_allclose(y_pred, y_test)


def test_xor():
Expand Down Expand Up @@ -2143,8 +2150,10 @@ def test_poisson_vs_mse():
assert metric_poi < 0.75 * metric_dummy


# TODO: Tree-based models do not preserve dtype on their fitted attributes.
# Add a test using `global_dtype` when this is the case.
@pytest.mark.parametrize("criterion", REG_CRITERIONS)
def test_decision_tree_regressor_sample_weight_consistency(criterion):
def test_decision_tree_regressor_sample_weight_consistency(criterion, global_dtype):
"""Test that the impact of sample_weight is consistent."""
tree_params = dict(criterion=criterion)
tree = DecisionTreeRegressor(**tree_params, random_state=42)
Expand All @@ -2161,6 +2170,9 @@ def test_decision_tree_regressor_sample_weight_consistency(criterion):
# make it positive in order to work also for poisson criterion
y += np.min(y) + 0.1

X = X.astype(global_dtype, copy=False)
y = y.astype(global_dtype, copy=False)

# check that multiplying sample_weight by 2 is equivalent
# to repeating corresponding samples twice
X2 = np.concatenate([X, X[: n_samples // 2]], axis=0)
Expand All @@ -2178,7 +2190,9 @@ def test_decision_tree_regressor_sample_weight_consistency(criterion):
# Thresholds, tree.tree_.threshold, and values, tree.tree_.value, are not
# exactly the same, but on the training set, those differences do not
# matter and thus predictions are the same.
assert_allclose(tree1.predict(X), tree2.predict(X))
y_pred1 = tree1.predict(X)
y_pred2 = tree2.predict(X)
assert_allclose(y_pred1, y_pred2)


# TODO: Remove in v1.2
Expand Down