@@ -28,20 +28,21 @@ define([
28
28
class Evaluation extends PopupComponent {
29
29
_init ( ) {
30
30
super . _init ( ) ;
31
+ this . config . importButton = true ;
31
32
this . config . dataview = false ;
32
33
33
34
this . state = {
34
35
modelType : 'rgs' ,
35
36
predictData : 'pred' ,
36
37
targetData : 'y_test' ,
38
+ // regression
39
+ r_squared : true , mae : true , mape : false , rmse : true , scatter_plot : false ,
37
40
// classification
38
41
confusion_matrix : true , report : true ,
39
42
accuracy : false , precision : false , recall : false , f1_score : false ,
40
- // regression
41
- coefficient : false , intercept : false , r_squared : true ,
42
- mae : false , mape : false , rmse : true , scatter_plot : false ,
43
+ roc_curve : false , auc : false ,
43
44
// clustering
44
- sizeOfClusters : true , silhouetteScore : true ,
45
+ silhouetteScore : true , ari : false , nm : false ,
45
46
...this . state
46
47
}
47
48
}
@@ -63,7 +64,39 @@ define([
63
64
64
65
$ ( that . wrapSelector ( '.vp-eval-box' ) ) . hide ( ) ;
65
66
$ ( that . wrapSelector ( '.vp-eval-' + modelType ) ) . show ( ) ;
66
- } )
67
+
68
+ if ( modelType == 'clf' ) {
69
+ // Classification - model selection
70
+ if ( that . checkToShowModel ( ) == true ) {
71
+ $ ( that . wrapSelector ( '.vp-ev-model' ) ) . show ( ) ;
72
+ }
73
+ }
74
+ } ) ;
75
+
76
+ // open model selection show
77
+ $ ( this . wrapSelector ( '.vp-eval-check' ) ) . on ( 'change' , function ( ) {
78
+ let checked = $ ( this ) . prop ( 'checked' ) ;
79
+
80
+ if ( checked ) {
81
+ $ ( that . wrapSelector ( '.vp-ev-model' ) ) . show ( ) ;
82
+ } else {
83
+ if ( that . checkToShowModel ( ) == false ) {
84
+ $ ( that . wrapSelector ( '.vp-ev-model' ) ) . hide ( ) ;
85
+ }
86
+ }
87
+ } ) ;
88
+ }
89
+
90
+ /**
91
+ * Check if anything checked available ( > 0)
92
+ * @returns
93
+ */
94
+ checkToShowModel ( ) {
95
+ let checked = $ ( this . wrapSelector ( '.vp-eval-check:checked' ) ) . length ;
96
+ if ( checked > 0 ) {
97
+ return true ;
98
+ }
99
+ return false ;
67
100
}
68
101
69
102
templateForBody ( ) {
@@ -72,7 +105,7 @@ define([
72
105
$ ( page ) . find ( '.vp-eval-box' ) . hide ( ) ;
73
106
$ ( page ) . find ( '.vp-eval-' + this . state . modelType ) . show ( ) ;
74
107
75
- // varselector TEST:
108
+ // varselector
76
109
let varSelector = new VarSelector2 ( this . wrapSelector ( ) , [ 'DataFrame' , 'list' , 'str' ] ) ;
77
110
varSelector . setComponentID ( 'predictData' ) ;
78
111
varSelector . addClass ( 'vp-state vp-input' ) ;
@@ -85,6 +118,28 @@ define([
85
118
varSelector . setValue ( this . state . targetData ) ;
86
119
$ ( page ) . find ( '#targetData' ) . replaceWith ( varSelector . toTagString ( ) ) ;
87
120
121
+ // model
122
+ // set model list
123
+ let modelOptionTag = new com_String ( ) ;
124
+ vpKernel . getModelList ( 'Classification' ) . then ( function ( resultObj ) {
125
+ let { result } = resultObj ;
126
+ var modelList = JSON . parse ( result ) ;
127
+ modelList && modelList . forEach ( model => {
128
+ let selectFlag = '' ;
129
+ if ( model . varName == that . state . model ) {
130
+ selectFlag = 'selected' ;
131
+ }
132
+ modelOptionTag . appendFormatLine ( '<option value="{0}" data-type="{1}" {2}>{3} ({4})</option>' ,
133
+ model . varName , model . varType , selectFlag , model . varName , model . varType ) ;
134
+ } ) ;
135
+ $ ( page ) . find ( '#model' ) . html ( modelOptionTag . toString ( ) ) ;
136
+ $ ( that . wrapSelector ( '#model' ) ) . html ( modelOptionTag . toString ( ) ) ;
137
+
138
+ if ( ! that . state . model || that . state . model == '' ) {
139
+ that . state . model = $ ( that . wrapSelector ( '#model' ) ) . val ( ) ;
140
+ }
141
+ } ) ;
142
+
88
143
// load state
89
144
let that = this ;
90
145
Object . keys ( this . state ) . forEach ( key => {
@@ -114,8 +169,22 @@ define([
114
169
}
115
170
} ) ;
116
171
172
+ if ( this . state . modelType == 'clf' ) {
173
+ if ( this . state . roc_curve == true || this . state . auc == true ) {
174
+ $ ( page ) . find ( '.vp-ev-model' ) . show ( ) ;
175
+ } else {
176
+ $ ( page ) . find ( '.vp-ev-model' ) . hide ( ) ;
177
+ }
178
+ } else {
179
+ $ ( page ) . find ( '.vp-ev-model' ) . hide ( ) ;
180
+ }
181
+
117
182
return page ;
118
183
}
184
+
185
+ generateImportCode ( ) {
186
+ return 'from sklearn import metrics' ;
187
+ }
119
188
120
189
generateCode ( ) {
121
190
let codeCells = [ ] ;
@@ -124,6 +193,7 @@ define([
124
193
modelType, predictData, targetData,
125
194
// classification
126
195
confusion_matrix, report, accuracy, precision, recall, f1_score, roc_curve, auc,
196
+ model,
127
197
// regression
128
198
coefficient, intercept, r_squared, mae, mape, rmse, scatter_plot,
129
199
// clustering
@@ -173,7 +243,7 @@ define([
173
243
if ( roc_curve ) {
174
244
code = new com_String ( ) ;
175
245
code . appendLine ( "# ROC Curve" ) ;
176
- code . appendFormatLine ( "fpr, tpr, thresholds = roc_curve({0}, svc .decision_function({1}} ))" , predictData , targetData ) ;
246
+ code . appendFormatLine ( "fpr, tpr, thresholds = metrics. roc_curve({0}, {1} .decision_function({2} ))" , predictData , model , targetData ) ;
177
247
code . appendLine ( "plt.plot(fpr, tpr, label='ROC Curve')" ) ;
178
248
code . appendLine ( "plt.xlabel('Sensitivity') " ) ;
179
249
code . append ( "plt.ylabel('Specificity') " )
@@ -182,8 +252,7 @@ define([
182
252
if ( auc ) {
183
253
code = new com_String ( ) ;
184
254
code . appendLine ( "# AUC" ) ;
185
- code . appendFormatLine ( "fpr, tpr, thresholds = roc_curve({0}, svc.decision_function({1}}))" , predictData , targetData ) ;
186
- code . append ( "metrics.auc(fpr, tpr)" ) ;
255
+ code . appendFormat ( "metrics.roc_auc_score({0}, {1}.decision_function({2}))" , predictData , model , targetData ) ;
187
256
codeCells . push ( code . toString ( ) ) ;
188
257
}
189
258
}
@@ -232,7 +301,7 @@ define([
232
301
code . appendLine ( '# Regression plot' ) ;
233
302
code . appendFormatLine ( 'plt.scatter({0}, {1})' , targetData , predictData ) ;
234
303
code . appendFormatLine ( "plt.xlabel('{0}')" , targetData ) ;
235
- code . appendFormatLine ( "plt.ylabel('{1 }')" , predictData ) ;
304
+ code . appendFormatLine ( "plt.ylabel('{0 }')" , predictData ) ;
236
305
code . append ( 'plt.show()' ) ;
237
306
codeCells . push ( code . toString ( ) ) ;
238
307
}
0 commit comments