Skip to content

Commit 12755c6

Browse files
author
minjk-bl
committed
Fix Model options
1 parent 2e369e7 commit 12755c6

File tree

2 files changed

+86
-23
lines changed

2 files changed

+86
-23
lines changed

js/com/component/ModelEditor.js

Lines changed: 79 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -85,26 +85,28 @@ define([
8585
},
8686
'predict': {
8787
name: 'predict',
88-
code: '${model}.predict(${featureData})',
88+
code: '${allocatePredict} = ${model}.predict(${featureData})',
8989
description: 'Predict the closest target data X belongs to.',
9090
options: [
91-
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_train' }
91+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_test' },
92+
{ name: 'allocatePredict', label: 'Allocate to', component: ['input'], placeholder: 'New variable', default: 'pred' }
9293
]
9394
},
9495
'predict_proba': {
9596
name: 'predict_proba',
96-
code: '${model}.predict_proba(${featureData})',
97+
code: '${allocatePredict} = ${model}.predict_proba(${featureData})',
9798
description: 'Predict class probabilities for X.',
9899
options: [
99-
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_train' }
100+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_test' },
101+
{ name: 'allocatePredict', label: 'Allocate to', component: ['input'], placeholder: 'New variable', default: 'pred' }
100102
]
101103
},
102104
'transform': {
103105
name: 'transform',
104106
code: '${allocateTransform} = ${model}.transform(${featureData})',
105107
description: 'Apply dimensionality reduction to X.',
106108
options: [
107-
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_train' },
109+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
108110
{ name: 'allocateTransform', label: 'Allocate to', component: ['input'], placeholder: 'New variable' }
109111
]
110112
}
@@ -113,7 +115,23 @@ define([
113115
switch (category) {
114116
case 'Data Preparation':
115117
actions = {
116-
'fit': defaultActions['fit'],
118+
'fit': {
119+
name: 'fit',
120+
code: '${model}.fit(${featureData})',
121+
description: 'Fit Encoder/Scaler to X.',
122+
options: [
123+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' }
124+
]
125+
},
126+
'fit_transform': {
127+
name: 'fit_transform',
128+
code: '${allocateTransform} = ${model}.fit_transform(${featureData})',
129+
description: 'Fit Encoder/Scaler to X, then transform X.',
130+
options: [
131+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
132+
{ name: 'allocateTransform', label: 'Allocate to', component: ['input'], placeholder: 'New variable' }
133+
]
134+
},
117135
'transform': {
118136
...defaultActions['transform'],
119137
description: 'Transform labels to normalized encoding.'
@@ -141,11 +159,31 @@ define([
141159
'predict': defaultActions['predict'],
142160
'predict_proba': defaultActions['predict_proba'],
143161
}
162+
if (['LogisticRegression', 'SVC', 'GradientBoostingClassifier'].includes(modelType)) {
163+
actions = {
164+
...actions,
165+
'decision_function': {
166+
name: 'decision_function',
167+
code: '${allocateScore} = ${model}.decision_function(${featureData})',
168+
description: 'Compute the decision function of X.',
169+
options: [
170+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
171+
{ name: 'allocateScore', label: 'Allocate to', component: ['input'], placeholder: 'New variable' }
172+
]
173+
}
174+
}
175+
}
144176
break;
145177
case 'Auto ML':
146178
actions = {
147179
'fit': defaultActions['fit'],
148-
'predict': defaultActions['predict'],
180+
'predict': defaultActions['predict']
181+
}
182+
if (modelType == 'TPOTClassifier') {
183+
actions = {
184+
...actions,
185+
'predict_proba': defaultActions['predict_proba']
186+
}
149187
}
150188
break;
151189
case 'Clustering':
@@ -155,10 +193,11 @@ define([
155193
'fit': defaultActions['fit'],
156194
'fit_predict': {
157195
name: 'fit_predict',
158-
code: '${model}.fit_predict(${featureData})',
196+
code: '${allocatePredict} = ${model}.fit_predict(${featureData})',
159197
description: 'Compute clusters from a data or distance matrix and predict labels.',
160198
options: [
161-
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_train' }
199+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
200+
{ name: 'allocatePredict', label: 'Allocate to', component: ['input'], placeholder: 'New variable', default: 'pred' }
162201
]
163202
}
164203
}
@@ -167,6 +206,37 @@ define([
167206
actions = {
168207
'fit': defaultActions['fit'],
169208
'predict': defaultActions['predict'],
209+
'fit_predict': {
210+
name: 'fit_predict',
211+
code: '${allocatePredict} = ${model}.fit_predict(${featureData})',
212+
description: 'Compute cluster centers and predict cluster index for each sample.',
213+
options: [
214+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
215+
{ name: 'allocatePredict', label: 'Allocate to', component: ['input'], placeholder: 'New variable', default: 'pred' }
216+
]
217+
}
218+
}
219+
if (modelType == 'KMeans') {
220+
actions = {
221+
...actions,
222+
'fit_transform': {
223+
name: 'fit_transform',
224+
code: '${model}.fit_transform(${featureData})',
225+
description: 'Compute clustering and transform X to cluster-distance space.',
226+
options: [
227+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X_train' }
228+
]
229+
},
230+
'transform': {
231+
name: 'transform',
232+
code: '${allocateTransform} = ${model}.transform(${featureData})',
233+
description: 'Transform X to a cluster-distance space.',
234+
options: [
235+
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
236+
{ name: 'allocateTransform', label: 'Allocate to', component: ['input'], placeholder: 'New variable' }
237+
]
238+
}
239+
}
170240
}
171241
break;
172242
case 'Dimension Reduction':
@@ -303,15 +373,6 @@ define([
303373
options: [
304374
{ name: 'allocateCenters', label: 'Allocate to', component: ['input'], placeholder: 'New variable' }
305375
]
306-
},
307-
'transform': {
308-
name: 'transform',
309-
code: '${allocateTransform} = ${model}.transform(${featureData})',
310-
description: 'Transform X to a cluster-distance space.',
311-
options: [
312-
{ name: 'featureData', label: 'Feature Data', component: ['var_select'], var_type: ['DataFrame', 'Series'], default: 'X' },
313-
{ name: 'allocateTransform', label: 'Allocate to', component: ['input'], placeholder: 'New variable' }
314-
]
315376
}
316377
}
317378
}

js/m_ml/evaluation.js

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,13 @@ define([
161161
code.appendLine("# ROC Curve");
162162
code.appendFormatLine("fpr, tpr, thresholds = roc_curve({0}, svc.decision_function({1}}))", predictData, targetData);
163163
code.appendLine("plt.plot(fpr, tpr, label='ROC Curve')");
164-
code.appendLine("plt. xlabel('Sensitivity') ");
165-
code.appendLine("plt. ylabel('Specificity') ")
164+
code.appendLine("plt.xlabel('Sensitivity') ");
165+
code.appendLine("plt.ylabel('Specificity') ")
166166
}
167167
if (auc) {
168-
// FIXME:
168+
code.appendLine("# AUC");
169+
code.appendFormatLine("fpr, tpr, thresholds = roc_curve({0}, svc.decision_function({1}}))", predictData, targetData);
170+
code.appendLine("metrics.auc(fpr, tpr)");
169171
}
170172
}
171173

@@ -221,11 +223,11 @@ define([
221223
code.appendFormatLine("print(f'Silhouette score: {metrics.cluster.silhouette_score({0}, {1})}')", targetData, predictData);
222224
}
223225
if (ari) {
224-
code.appendLine("# ARI"); // FIXME:
226+
code.appendLine("# ARI");
225227
code.appendFormatLine("print(f'ARI: {metrics.cluster.adjusted_rand_score({0}, {1})}')", targetData, predictData);
226228
}
227229
if (nm) {
228-
code.appendLine("# NM"); // FIXME:
230+
code.appendLine("# NM");
229231
code.appendFormatLine("print(f'NM: {metrics.cluster.normalized_mutual_info_score({0}, {1})}')", targetData, predictData);
230232
}
231233
}

0 commit comments

Comments
 (0)