Skip to content

Commit 31b28a1

Browse files
author
minjk-bl
committed
Add display() to evaluation
1 parent 6db2412 commit 31b28a1

File tree

1 file changed

+25
-9
lines changed

1 file changed

+25
-9
lines changed

visualpython/js/m_ml/evaluation.js

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ define([
199199
sizeOfClusters, silhouetteScore, ari, nmi,
200200
clusteredIndex, featureData2, targetData2
201201
} = this.state;
202+
let needDisplay = false;
202203

203204
//====================================================================
204205
// Classfication
@@ -207,7 +208,8 @@ define([
207208
if (confusion_matrix) {
208209
code = new com_String();
209210
code.appendLine("# Confusion Matrix");
210-
code.appendFormat('pd.crosstab({0}, {1}, margins=True)', targetData, predictData);
211+
code.appendFormat('display(pd.crosstab({0}, {1}, margins=True))', targetData, predictData);
212+
needDisplay = true;
211213
codeCells.push(code.toString());
212214
}
213215
if (report) {
@@ -219,25 +221,29 @@ define([
219221
if (accuracy) {
220222
code = new com_String();
221223
code.appendLine("# Accuracy");
222-
code.appendFormat('metrics.accuracy_score({0}, {1})', targetData, predictData);
224+
code.appendFormat('display(metrics.accuracy_score({0}, {1}))', targetData, predictData);
225+
needDisplay = true;
223226
codeCells.push(code.toString());
224227
}
225228
if (precision) {
226229
code = new com_String();
227230
code.appendLine("# Precision");
228-
code.appendFormat("metrics.precision_score({0}, {1}, average='weighted')", targetData, predictData);
231+
code.appendFormat("display(metrics.precision_score({0}, {1}, average='weighted'))", targetData, predictData);
232+
needDisplay = true;
229233
codeCells.push(code.toString());
230234
}
231235
if (recall) {
232236
code = new com_String();
233237
code.appendLine("# Recall");
234-
code.appendFormat("metrics.recall_score({0}, {1}, average='weighted')", targetData, predictData);
238+
code.appendFormat("display(metrics.recall_score({0}, {1}, average='weighted'))", targetData, predictData);
239+
needDisplay = true;
235240
codeCells.push(code.toString());
236241
}
237242
if (f1_score) {
238243
code = new com_String();
239244
code.appendLine("# F1-score");
240-
code.appendFormat("metrics.f1_score({0}, {1}, average='weighted')", targetData, predictData);
245+
code.appendFormat("display(metrics.f1_score({0}, {1}, average='weighted'))", targetData, predictData);
246+
needDisplay = true;
241247
codeCells.push(code.toString());
242248
}
243249
// if (roc_curve) {
@@ -272,13 +278,15 @@ define([
272278
if (r_squared) {
273279
code = new com_String();
274280
code.appendLine("# R square");
275-
code.appendFormat('metrics.r2_score({0}, {1})', targetData, predictData);
281+
code.appendFormat('display(metrics.r2_score({0}, {1}))', targetData, predictData);
282+
needDisplay = true;
276283
codeCells.push(code.toString());
277284
}
278285
if (mae) {
279286
code = new com_String();
280287
code.appendLine("# MAE(Mean Absolute Error)");
281-
code.appendFormat('metrics.mean_absolute_error({0}, {1})', targetData, predictData);
288+
code.appendFormat('display(metrics.mean_absolute_error({0}, {1}))', targetData, predictData);
289+
needDisplay = true;
282290
codeCells.push(code.toString());
283291
}
284292
if (mape) {
@@ -287,13 +295,15 @@ define([
287295
code.appendLine('def MAPE(y_test, y_pred):');
288296
code.appendLine(' return np.mean(np.abs((y_test - pred) / y_test)) * 100');
289297
code.appendLine();
290-
code.appendFormat('MAPE({0}, {1})', targetData, predictData);
298+
code.appendFormat('display(MAPE({0}, {1}))', targetData, predictData);
299+
needDisplay = true;
291300
codeCells.push(code.toString());
292301
}
293302
if (rmse) {
294303
code = new com_String();
295304
code.appendLine("# RMSE(Root Mean Squared Error)");
296-
code.appendFormat('metrics.mean_squared_error({0}, {1})**0.5', targetData, predictData);
305+
code.appendFormat('display(metrics.mean_squared_error({0}, {1})**0.5)', targetData, predictData);
306+
needDisplay = true;
297307
codeCells.push(code.toString());
298308
}
299309
if (scatter_plot) {
@@ -333,6 +343,12 @@ define([
333343
codeCells.push(code.toString());
334344
}
335345
}
346+
if (needDisplay === true) {
347+
codeCells = [
348+
"from IPython.display import display",
349+
...codeCells
350+
];
351+
}
336352
// return as seperated cells
337353
return codeCells;
338354
}

0 commit comments

Comments
 (0)