# test_linear_model.py
# Tests for the linear/logistic regression explainers (from a fork of
# oegedijk/explainerdashboard).
import pytest
import pandas as pd
import numpy as np
import shap
import plotly.graph_objects as go
def test_linreg_explainer_len(precalculated_linear_regression_explainer, testlen):
    """len() of the linear regression explainer must equal the ``testlen`` fixture."""
    n_rows = len(precalculated_linear_regression_explainer)
    assert n_rows == testlen
def test_linreg_int_idx(precalculated_linear_regression_explainer, test_names):
    """``get_idx`` must map the first entry of ``test_names`` to integer index 0."""
    explainer = precalculated_linear_regression_explainer
    idx = explainer.get_idx(test_names[0])
    assert idx == 0
def test_linreg_random_index(precalculated_linear_regression_explainer):
    """``random_index()`` returns an int by default and a str with ``return_str=True``."""
    explainer = precalculated_linear_regression_explainer
    int_index = explainer.random_index()
    assert isinstance(int_index, int)
    str_index = explainer.random_index(return_str=True)
    assert isinstance(str_index, str)
def test_linreg_preds(precalculated_linear_regression_explainer):
    """The ``preds`` attribute of the regression explainer is a numpy array."""
    preds = precalculated_linear_regression_explainer.preds
    assert isinstance(preds, np.ndarray)
def test_linreg_pred_percentiles(precalculated_linear_regression_explainer):
    """``pred_percentiles()`` returns a numpy array."""
    percentiles = precalculated_linear_regression_explainer.pred_percentiles()
    assert isinstance(percentiles, np.ndarray)
def test_linreg_permutation_importances(precalculated_linear_regression_explainer):
    """``get_permutation_importances_df()`` returns a pandas DataFrame."""
    explainer = precalculated_linear_regression_explainer
    importances = explainer.get_permutation_importances_df()
    assert isinstance(importances, pd.DataFrame)
def test_linreg_metrics(precalculated_linear_regression_explainer):
    """Both ``metrics()`` and ``metrics_descriptions()`` return dicts."""
    explainer = precalculated_linear_regression_explainer
    metric_values = explainer.metrics()
    assert isinstance(metric_values, dict)
    metric_docs = explainer.metrics_descriptions()
    assert isinstance(metric_docs, dict)
def test_linreg_mean_abs_shap_df(precalculated_linear_regression_explainer):
    """``get_mean_abs_shap_df()`` returns a pandas DataFrame."""
    mean_abs_shap = precalculated_linear_regression_explainer.get_mean_abs_shap_df()
    assert isinstance(mean_abs_shap, pd.DataFrame)
def test_linreg_top_interactions(precalculated_linear_regression_explainer):
    """``top_shap_interactions("Age")`` returns a list, with and without ``topx``."""
    explainer = precalculated_linear_regression_explainer
    all_interactions = explainer.top_shap_interactions("Age")
    assert isinstance(all_interactions, list)
    top_four = explainer.top_shap_interactions("Age", topx=4)
    assert isinstance(top_four, list)
def test_linreg_contrib_df(precalculated_linear_regression_explainer):
    """``get_contrib_df`` for row 0 returns a DataFrame, with and without ``topx``."""
    explainer = precalculated_linear_regression_explainer
    full_contrib = explainer.get_contrib_df(0)
    assert isinstance(full_contrib, pd.DataFrame)
    top_contrib = explainer.get_contrib_df(0, topx=3)
    assert isinstance(top_contrib, pd.DataFrame)
def test_linreg_shap_base_value(precalculated_linear_regression_explainer):
    """``shap_base_value()`` is a scalar: a numpy floating or a builtin float."""
    base_value = precalculated_linear_regression_explainer.shap_base_value()
    assert isinstance(base_value, (np.floating, float))
def test_linreg_shap_values_shape(precalculated_linear_regression_explainer):
    """The SHAP values frame has one row per explainer row and one column per merged col."""
    explainer = precalculated_linear_regression_explainer
    shap_df = explainer.get_shap_values_df()
    n_rows, n_cols = len(explainer), len(explainer.merged_cols)
    assert shap_df.shape == (n_rows, n_cols)
def test_linreg_shap_values(precalculated_linear_regression_explainer):
    """``get_shap_values_df()`` returns a pandas DataFrame."""
    shap_df = precalculated_linear_regression_explainer.get_shap_values_df()
    assert isinstance(shap_df, pd.DataFrame)
def test_linreg_mean_abs_shap(precalculated_linear_regression_explainer):
    """``get_mean_abs_shap_df()`` returns a pandas DataFrame.

    Note: this makes the same check as ``test_linreg_mean_abs_shap_df``.
    """
    explainer = precalculated_linear_regression_explainer
    result = explainer.get_mean_abs_shap_df()
    assert isinstance(result, pd.DataFrame)
def test_linreg_calculate_properties(precalculated_linear_regression_explainer):
    """Smoke test: ``calculate_properties`` runs without raising.

    It is called with ``include_interactions=False``. No return value is checked.
    """
    explainer = precalculated_linear_regression_explainer
    explainer.calculate_properties(include_interactions=False)
def test_linreg_pdp_df(precalculated_linear_regression_explainer):
    """``pdp_df`` returns a DataFrame for several columns, with and without ``index``."""
    explainer = precalculated_linear_regression_explainer
    # (column, extra keyword arguments) in the order they are exercised.
    cases = [
        ("Age", {}),
        ("Gender", {}),
        ("Deck", {}),
        ("Age", {"index": 0}),
        ("Gender", {"index": 0}),
    ]
    for col, kwargs in cases:
        assert isinstance(explainer.pdp_df(col, **kwargs), pd.DataFrame)
def test_logreg_preds(precalculated_logistic_regression_explainer):
    """The ``preds`` attribute of the classifier explainer is a numpy array."""
    preds = precalculated_logistic_regression_explainer.preds
    assert isinstance(preds, np.ndarray)
def test_logreg_pred_percentiles(precalculated_logistic_regression_explainer):
    """``pred_percentiles()`` returns a numpy array."""
    percentiles = precalculated_logistic_regression_explainer.pred_percentiles()
    assert isinstance(percentiles, np.ndarray)
def test_logreg_columns_ranked_by_shap(precalculated_logistic_regression_explainer):
    """``columns_ranked_by_shap()`` returns a list."""
    ranked = precalculated_logistic_regression_explainer.columns_ranked_by_shap()
    assert isinstance(ranked, list)
def test_logreg_permutation_importances(precalculated_logistic_regression_explainer):
    """``get_permutation_importances_df()`` returns a pandas DataFrame."""
    explainer = precalculated_logistic_regression_explainer
    importances = explainer.get_permutation_importances_df()
    assert isinstance(importances, pd.DataFrame)
def test_logreg_metrics(precalculated_logistic_regression_explainer):
    """Both ``metrics()`` and ``metrics_descriptions()`` return dicts."""
    explainer = precalculated_logistic_regression_explainer
    metric_values = explainer.metrics()
    assert isinstance(metric_values, dict)
    metric_docs = explainer.metrics_descriptions()
    assert isinstance(metric_docs, dict)
def test_logreg_mean_abs_shap_df(precalculated_logistic_regression_explainer):
    """``get_mean_abs_shap_df()`` returns a pandas DataFrame."""
    mean_abs_shap = precalculated_logistic_regression_explainer.get_mean_abs_shap_df()
    assert isinstance(mean_abs_shap, pd.DataFrame)
def test_logreg_contrib_df(precalculated_logistic_regression_explainer):
    """``get_contrib_df`` for row 0 returns a DataFrame, with and without ``topx``."""
    explainer = precalculated_logistic_regression_explainer
    full_contrib = explainer.get_contrib_df(0)
    assert isinstance(full_contrib, pd.DataFrame)
    top_contrib = explainer.get_contrib_df(0, topx=3)
    assert isinstance(top_contrib, pd.DataFrame)
def test_logreg_shap_base_value(precalculated_logistic_regression_explainer):
    """``shap_base_value()`` is a scalar: a numpy floating or a builtin float."""
    base_value = precalculated_logistic_regression_explainer.shap_base_value()
    assert isinstance(base_value, (np.floating, float))
def test_logreg_shap_values_shape(precalculated_logistic_regression_explainer):
    """The SHAP values frame has one row per explainer row and one column per merged col."""
    explainer = precalculated_logistic_regression_explainer
    shap_df = explainer.get_shap_values_df()
    n_rows, n_cols = len(explainer), len(explainer.merged_cols)
    assert shap_df.shape == (n_rows, n_cols)
def test_logreg_shap_values(precalculated_logistic_regression_explainer):
    """``get_shap_values_df()`` returns a pandas DataFrame."""
    shap_df = precalculated_logistic_regression_explainer.get_shap_values_df()
    assert isinstance(shap_df, pd.DataFrame)
def test_logreg_mean_abs_shap(precalculated_logistic_regression_explainer):
    """``get_mean_abs_shap_df()`` returns a pandas DataFrame.

    Note: this makes the same check as ``test_logreg_mean_abs_shap_df``.
    """
    explainer = precalculated_logistic_regression_explainer
    result = explainer.get_mean_abs_shap_df()
    assert isinstance(result, pd.DataFrame)
def test_logreg_calculate_properties(precalculated_logistic_regression_explainer):
    """Smoke test: ``calculate_properties`` runs without raising.

    It is called with ``include_interactions=False``. No return value is checked.
    """
    explainer = precalculated_logistic_regression_explainer
    explainer.calculate_properties(include_interactions=False)
def test_logreg_pdp_df(precalculated_logistic_regression_explainer):
    """``pdp_df`` returns a DataFrame for several columns, with and without ``index``."""
    explainer = precalculated_logistic_regression_explainer
    # (column, extra keyword arguments) in the order they are exercised.
    cases = [
        ("Age", {}),
        ("Gender", {}),
        ("Deck", {}),
        ("Age", {"index": 0}),
        ("Gender", {"index": 0}),
    ]
    for col, kwargs in cases:
        assert isinstance(explainer.pdp_df(col, **kwargs), pd.DataFrame)
def test_logreg_pos_label(precalculated_logistic_regression_explainer):
    """Check that ``pos_label`` can be set by int index or by label string.

    After assigning the label string "Not survived", ``pos_label`` must read
    back as the int ``0``, and ``pos_label_str`` as that string.

    The explainer's original ``pos_label`` is restored afterwards. If the
    fixture object is shared between tests (e.g. a session- or module-scoped
    fixture), the mutation then cannot leak into later tests such as the
    pred_probas / precision / lift-curve checks.
    """
    explainer = precalculated_logistic_regression_explainer
    original_pos_label = explainer.pos_label
    try:
        explainer.pos_label = 1
        explainer.pos_label = "Not survived"
        assert isinstance(explainer.pos_label, int)
        assert isinstance(explainer.pos_label_str, str)
        assert explainer.pos_label == 0
        assert explainer.pos_label_str == "Not survived"
    finally:
        # Undo the mutation even if an assertion fails above.
        explainer.pos_label = original_pos_label
def test_logreg_pred_probas(precalculated_logistic_regression_explainer):
    """``pred_probas()`` returns a numpy array."""
    probas = precalculated_logistic_regression_explainer.pred_probas()
    assert isinstance(probas, np.ndarray)
def test_logreg_metrics_cutoff(precalculated_logistic_regression_explainer):
    """Check that ``metrics()`` returns a dict with the default and a custom cutoff.

    This test is named distinctly from ``test_logreg_metrics``, which also
    exists in this module. Reusing that name would rebind it at import time,
    and pytest would then silently collect only this later definition.
    """
    explainer = precalculated_logistic_regression_explainer
    assert isinstance(explainer.metrics(), dict)
    assert isinstance(explainer.metrics(cutoff=0.9), dict)
def test_logreg_precision_df(precalculated_logistic_regression_explainer):
    """``get_precision_df`` returns a DataFrame for default, multiclass and quantile modes."""
    explainer = precalculated_logistic_regression_explainer
    for kwargs in ({}, {"multiclass": True}, {"quantiles": 4}):
        assert isinstance(explainer.get_precision_df(**kwargs), pd.DataFrame)
def test_logreg_lift_curve_df(precalculated_logistic_regression_explainer):
    """``get_liftcurve_df()`` returns a pandas DataFrame."""
    lift_curve = precalculated_logistic_regression_explainer.get_liftcurve_df()
    assert isinstance(lift_curve, pd.DataFrame)
##### KERNEL TESTS
def test_logistic_regression_kernel_shap_values(logistic_regression_kernel_explainer):
    """Check the SHAP outputs of the ``logistic_regression_kernel_explainer`` fixture.

    Verifies three things:
    - the SHAP base value is a scalar float;
    - ``get_shap_values_df()`` returns a DataFrame;
    - that DataFrame has one row per explainer row and one column per
      entry in ``merged_cols``.
    """
    explainer = logistic_regression_kernel_explainer
    assert isinstance(explainer.shap_base_value(), (np.floating, float))
    # Fetch the SHAP values once: recomputing them may be costly if the
    # explainer does not cache them. Checking the type before the shape
    # also gives a clearer failure than an AttributeError on `.shape`.
    shap_values_df = explainer.get_shap_values_df()
    assert isinstance(shap_values_df, pd.DataFrame)
    assert shap_values_df.shape == (len(explainer), len(explainer.merged_cols))
def test_linear_regression_kernel_shap_values(linear_regression_kernel_explainer):
    """Check the SHAP outputs of the ``linear_regression_kernel_explainer`` fixture.

    Verifies three things:
    - the SHAP base value is a scalar float;
    - ``get_shap_values_df()`` returns a DataFrame;
    - that DataFrame has one row per explainer row and one column per
      entry in ``merged_cols``.
    """
    explainer = linear_regression_kernel_explainer
    assert isinstance(explainer.shap_base_value(), (np.floating, float))
    # Fetch the SHAP values once: recomputing them may be costly if the
    # explainer does not cache them. Checking the type before the shape
    # also gives a clearer failure than an AttributeError on `.shape`.
    shap_values_df = explainer.get_shap_values_df()
    assert isinstance(shap_values_df, pd.DataFrame)
    assert shap_values_df.shape == (len(explainer), len(explainer.merged_cols))