Skip to content

Commit 9896428

Browse files
author
minjk-bl
committed
Move some functions(ROC, AUC) from Evaluation to Model Info
1 parent 42cfb0b commit 9896428

File tree

1 file changed

+24
-2
lines changed

1 file changed

+24
-2
lines changed

js/m_ml/ModelInfo.js

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ define([
546546
name: 'roc_curve',
547547
label: 'ROC Curve',
548548
import: 'from sklearn import metrics',
549-
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function(${roc_featureData}))\n\
549+
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.predict_proba(${roc_featureData}))\n\
550550
plt.plot(fpr, tpr, label='ROC Curve')\n\
551551
plt.xlabel('Sensitivity')\n\
552552
plt.ylabel('Specificity')\n\
@@ -561,7 +561,7 @@ plt.show()",
561561
name: 'auc',
562562
label: 'AUC',
563563
import: 'from sklearn import metrics',
564-
code: 'metrics.roc_auc_score(${auc_targetData}, ${model}.decision_function(${auc_featureData}))',
564+
code: 'metrics.roc_auc_score(${auc_targetData}, ${model}.predict_proba(${auc_featureData}))',
565565
description: '',
566566
options: [
567567
{ name: 'auc_targetData', label: 'Target Data', component: ['var_select'], var_type: ['DataFrame', 'Series', 'ndarray', 'list', 'dict'], value: 'y_test' },
@@ -570,6 +570,28 @@ plt.show()",
570570
},
571571
'permutation_importance': defaultInfos['permutation_importance']
572572
}
573+
574+
// use decision_function on ROC, AUC
575+
let decisionFunctionTypes = [
576+
'LogisticRegression', 'SVC', 'GradientBoostingClassifier'
577+
];
578+
if (decisionFunctionTypes.includes(modelType)) {
579+
infos = {
580+
...infos,
581+
'roc_curve': {
582+
...infos['roc_curve'],
583+
code: "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function(${roc_featureData}))\n\
584+
plt.plot(fpr, tpr, label='ROC Curve')\n\
585+
plt.xlabel('Sensitivity')\n\
586+
plt.ylabel('Specificity')\n\
587+
plt.show()"
588+
},
589+
'auc': {
590+
...infos['auc'],
591+
code: 'metrics.roc_auc_score(${auc_targetData}, ${model}.decision_function(${auc_featureData}))',
592+
}
593+
}
594+
}
573595
break;
574596
case 'Auto ML':
575597
infos = {

0 commit comments

Comments
 (0)