@@ -546,7 +546,7 @@ define([
546
546
name : 'roc_curve' ,
547
547
label : 'ROC Curve' ,
548
548
import : 'from sklearn import metrics' ,
549
- code : "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function (${roc_featureData}))\n\
549
+ code : "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.predict_proba (${roc_featureData}))\n\
550
550
plt.plot(fpr, tpr, label='ROC Curve')\n\
551
551
plt.xlabel('Sensitivity')\n\
552
552
plt.ylabel('Specificity')\n\
@@ -561,7 +561,7 @@ plt.show()",
561
561
name : 'auc' ,
562
562
label : 'AUC' ,
563
563
import : 'from sklearn import metrics' ,
564
- code : 'metrics.roc_auc_score(${auc_targetData}, ${model}.decision_function (${auc_featureData}))' ,
564
+ code : 'metrics.roc_auc_score(${auc_targetData}, ${model}.predict_proba (${auc_featureData}))' ,
565
565
description : '' ,
566
566
options : [
567
567
{ name : 'auc_targetData' , label : 'Target Data' , component : [ 'var_select' ] , var_type : [ 'DataFrame' , 'Series' , 'ndarray' , 'list' , 'dict' ] , value : 'y_test' } ,
@@ -570,6 +570,28 @@ plt.show()",
570
570
} ,
571
571
'permutation_importance' : defaultInfos [ 'permutation_importance' ]
572
572
}
573
+
574
+ // use decision_function on ROC, AUC
575
+ let decisionFunctionTypes = [
576
+ 'LogisticRegression' , 'SVC' , 'GradientBoostingClassifier'
577
+ ] ;
578
+ if ( decisionFunctionTypes . includes ( modelType ) ) {
579
+ infos = {
580
+ ...infos ,
581
+ 'roc_curve' : {
582
+ ...infos [ 'roc_curve' ] ,
583
+ code : "fpr, tpr, thresholds = metrics.roc_curve(${roc_targetData}, ${model}.decision_function(${roc_featureData}))\n\
584
+ plt.plot(fpr, tpr, label='ROC Curve')\n\
585
+ plt.xlabel('Sensitivity')\n\
586
+ plt.ylabel('Specificity')\n\
587
+ plt.show()"
588
+ } ,
589
+ 'auc' : {
590
+ ...infos [ 'auc' ] ,
591
+ code : 'metrics.roc_auc_score(${auc_targetData}, ${model}.decision_function(${auc_featureData}))' ,
592
+ }
593
+ }
594
+ }
573
595
break ;
574
596
case 'Auto ML' :
575
597
infos = {
0 commit comments